diff --git a/CMakeLists.txt b/CMakeLists.txt index 67e0c0104a..c7dd8050a9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,6 +13,7 @@ if (APPLE) endif () # Options +option(BUILD_LUA "Build Valkey Lua scripting engine" ON) option(BUILD_UNIT_TESTS "Build valkey-unit-tests" OFF) option(BUILD_TEST_MODULES "Build all test modules" OFF) option(BUILD_EXAMPLE_MODULES "Build example modules" OFF) @@ -33,6 +34,7 @@ add_subdirectory(tests) include(Packaging) # Clear cached variables from the cache +unset(BUILD_LUA CACHE) unset(BUILD_TESTS CACHE) unset(CLANGPP CACHE) unset(CLANG CACHE) diff --git a/README.md b/README.md index a2463494a9..9e2e5e2e1c 100644 --- a/README.md +++ b/README.md @@ -52,13 +52,17 @@ as libsystemd-dev on Debian/Ubuntu or systemd-devel on CentOS) and run: % make USE_SYSTEMD=yes -Since Valkey version 8.1, `fast_float` has been introduced as an optional -dependency, which can speed up sorted sets and other commands that use -the double datatype. To build with `fast_float` support, you'll need a +Since Valkey version 8.1, `fast_float` has been introduced as an optional +dependency, which can speed up sorted sets and other commands that use +the double datatype. To build with `fast_float` support, you'll need a C++ compiler and run: % make USE_FAST_FLOAT=yes +To build Valkey without the Lua engine: + + % make BUILD_LUA=no + To append a suffix to Valkey program names, use: % make PROG_SUFFIX="-alt" diff --git a/cmake/Modules/SourceFiles.cmake b/cmake/Modules/SourceFiles.cmake index edc8d66686..347908e8d1 100644 --- a/cmake/Modules/SourceFiles.cmake +++ b/cmake/Modules/SourceFiles.cmake @@ -100,13 +100,9 @@ set(VALKEY_SERVER_SRCS ${CMAKE_SOURCE_DIR}/src/mt19937-64.c ${CMAKE_SOURCE_DIR}/src/resp_parser.c ${CMAKE_SOURCE_DIR}/src/call_reply.c - ${CMAKE_SOURCE_DIR}/src/lua/script_lua.c ${CMAKE_SOURCE_DIR}/src/script.c ${CMAKE_SOURCE_DIR}/src/functions.c ${CMAKE_SOURCE_DIR}/src/scripting_engine.c - ${CMAKE_SOURCE_DIR}/src/lua/function_lua.c - ${CMAKE_SOURCE_DIR}/src/lua/engine_lua.c - ${CMAKE_SOURCE_DIR}/src/lua/debug_lua.c ${CMAKE_SOURCE_DIR}/src/trace/trace.c ${CMAKE_SOURCE_DIR}/src/trace/trace_rdb.c ${CMAKE_SOURCE_DIR}/src/trace/trace_aof.c diff --git a/cmake/Modules/ValkeySetup.cmake b/cmake/Modules/ValkeySetup.cmake index bcce2bf1f1..5cdb542f86 100644 --- a/cmake/Modules/ValkeySetup.cmake +++ b/cmake/Modules/ValkeySetup.cmake @@ -278,8 +278,8 @@ if (BUILD_SANITIZER) endif () include_directories("${CMAKE_SOURCE_DIR}/deps/libvalkey/include") +include_directories("${CMAKE_SOURCE_DIR}/src/modules/lua") include_directories("${CMAKE_SOURCE_DIR}/deps/linenoise") -include_directories("${CMAKE_SOURCE_DIR}/deps/lua/src") include_directories("${CMAKE_SOURCE_DIR}/deps/hdr_histogram") include_directories("${CMAKE_SOURCE_DIR}/deps/fpconv") @@ -293,6 +293,10 @@ endif () # Common compiler flags add_valkey_server_compiler_options("-pedantic") +if (NOT BUILD_LUA) + message(STATUS "Lua scripting engine is disabled") +endif() + # ---------------------------------------------------- # Build options (allocator, tls, rdma et al) - end # ---------------------------------------------------- diff --git a/deps/Makefile b/deps/Makefile index 7815fe1ec2..d46a101e9c 100644 --- a/deps/Makefile +++ b/deps/Makefile @@ -76,7 +76,7 @@ hdr_histogram: .make-prerequisites fpconv: .make-prerequisites @printf '%b %b\n' $(MAKECOLOR)MAKE$(ENDCOLOR) $(BINCOLOR)$@$(ENDCOLOR) - cd fpconv && $(MAKE) + cd fpconv && $(MAKE) CFLAGS="-fPIC $(CFLAGS)" .PHONY: fpconv @@ -85,12 +85,12 @@ ifeq ($(uname_S),SunOS) LUA_CFLAGS= -D__C99FEATURES__=1 endif -LUA_CFLAGS+= -Wall -DLUA_ANSI -DENABLE_CJSON_GLOBAL -DLUA_USE_MKSTEMP $(CFLAGS) +LUA_CFLAGS+= -Wall -DLUA_ANSI -DENABLE_CJSON_GLOBAL -DLUA_USE_MKSTEMP $(CFLAGS) -fPIC LUA_LDFLAGS+= $(LDFLAGS) ifeq ($(LUA_DEBUG),yes) LUA_CFLAGS+= -O0 -g -DLUA_USE_APICHECK else - LUA_CFLAGS+= -O2 + LUA_CFLAGS+= -O2 endif ifeq ($(LUA_COVERAGE),yes) LUA_CFLAGS += -fprofile-arcs -ftest-coverage diff --git a/deps/lua/CMakeLists.txt b/deps/lua/CMakeLists.txt index 0629d7f978..6a1396d66a 100644 --- a/deps/lua/CMakeLists.txt +++ b/deps/lua/CMakeLists.txt @@ -44,6 +44,7 @@ set(LUA_SRCS add_library(lualib STATIC "${LUA_SRCS}") target_include_directories(lualib PUBLIC "${LUA_SRC_DIR}") target_compile_definitions(lualib PRIVATE ENABLE_CJSON_GLOBAL) +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC") # Use mkstemp if available check_function_exists(mkstemp HAVE_MKSTEMP) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e4a81f6e44..2a60cb3758 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -9,7 +9,6 @@ message(STATUS "CFLAGS: ${CMAKE_C_FLAGS}") get_valkey_server_linker_option(VALKEY_SERVER_LDFLAGS) list(APPEND SERVER_LIBS "fpconv") -list(APPEND SERVER_LIBS "lualib") list(APPEND SERVER_LIBS "hdr_histogram") valkey_build_and_install_bin(valkey-server "${VALKEY_SERVER_SRCS}" "${VALKEY_SERVER_LDFLAGS}" "${SERVER_LIBS}" "redis-server") @@ -17,6 +16,23 @@ add_dependencies(valkey-server generate_commands_def) add_dependencies(valkey-server generate_fmtargs_h) add_dependencies(valkey-server release_header) +if (BUILD_LUA) + message(STATUS "Build Lua scripting engine module") + add_subdirectory(modules/lua) + add_dependencies(valkey-server valkeylua) + target_compile_definitions(valkey-server PRIVATE LUA_ENGINE_ENABLED) + target_compile_definitions(valkey-server PRIVATE LUA_ENGINE_LIB=libvalkeylua.so) + target_link_options(valkey-server PRIVATE -Wl,--disable-new-dtags) + + set(VALKEY_INSTALL_RPATH "") + set_target_properties(valkey-server PROPERTIES + INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR};${CMAKE_LIBRARY_OUTPUT_DIRECTORY}" + INSTALL_RPATH_USE_LINK_PATH TRUE + BUILD_WITH_INSTALL_RPATH TRUE + ) +endif() +unset(BUILD_LUA CACHE) + if (VALKEY_RELEASE_BUILD) # Enable LTO for Release build set_property(TARGET valkey-server PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) diff --git a/src/Makefile b/src/Makefile index 1ce9281678..06f20fa532 100644 --- a/src/Makefile +++ b/src/Makefile @@ -31,7 +31,7 @@ endif ifneq ($(OPTIMIZATION),-O0) OPTIMIZATION+=-fno-omit-frame-pointer endif -DEPENDENCY_TARGETS=libvalkey linenoise lua hdr_histogram fpconv +DEPENDENCY_TARGETS=libvalkey linenoise hdr_histogram fpconv NODEPS:=clean distclean # Default settings @@ -63,6 +63,7 @@ endif PREFIX?=/usr/local INSTALL_BIN=$(PREFIX)/bin +INSTALL_LIB=$(PREFIX)/lib INSTALL=install PKG_CONFIG?=pkg-config @@ -250,7 +251,25 @@ ifdef OPENSSL_PREFIX endif # Include paths to dependencies -FINAL_CFLAGS+= -I../deps/libvalkey/include -I../deps/linenoise -I../deps/lua/src -I../deps/hdr_histogram -I../deps/fpconv +FINAL_CFLAGS+= -I../deps/libvalkey/include -I../deps/linenoise -I../deps/hdr_histogram -I../deps/fpconv + +# Lua scripting engine module +LUA_MODULE_NAME:=modules/lua/libvalkeylua.so +ifeq ($(BUILD_LUA),no) + LUA_MODULE= + LUA_MODULE_INSTALL= +else + LUA_MODULE=$(LUA_MODULE_NAME) + LUA_MODULE_INSTALL=install-lua-module + + current_dir = $(shell pwd) + FINAL_CFLAGS+=-DLUA_ENGINE_ENABLED -DLUA_ENGINE_LIB=libvalkeylua.so +ifeq ($(uname_S),Darwin) + FINAL_LDFLAGS+= -Wl,-rpath,$(PREFIX)/lib:$(current_dir)/modules/lua +else + FINAL_LDFLAGS+= -Wl,-rpath,$(PREFIX)/lib:$(current_dir)/modules/lua -Wl,--disable-new-dtags +endif +endif # Determine systemd support and/or build preference (defaulting to auto-detection) BUILD_WITH_SYSTEMD=no @@ -423,7 +442,7 @@ ENGINE_NAME=valkey SERVER_NAME=$(ENGINE_NAME)-server$(PROG_SUFFIX) ENGINE_SENTINEL_NAME=$(ENGINE_NAME)-sentinel$(PROG_SUFFIX) ENGINE_TRACE_OBJ=trace/trace.o trace/trace_commands.o trace/trace_db.o trace/trace_cluster.o trace/trace_server.o trace/trace_rdb.o trace/trace_aof.o -ENGINE_SERVER_OBJ=threads_mngr.o adlist.o vector.o quicklist.o ae.o anet.o dict.o hashtable.o kvstore.o server.o sds.o zmalloc.o lzf_c.o lzf_d.o pqsort.o zipmap.o sha1.o ziplist.o release.o memory_prefetch.o io_threads.o networking.o util.o object.o db.o replication.o rdb.o t_string.o t_list.o t_set.o t_zset.o t_hash.o config.o aof.o pubsub.o multi.o debug.o sort.o intset.o syncio.o cluster.o cluster_legacy.o cluster_slot_stats.o crc16.o cluster_migrateslots.o endianconv.o commandlog.o eval.o bio.o rio.o rand.o memtest.o syscheck.o crcspeed.o crccombine.o crc64.o bitops.o sentinel.o notify.o setproctitle.o blocked.o hyperloglog.o latency.o sparkline.o valkey-check-rdb.o valkey-check-aof.o geo.o lazyfree.o module.o evict.o expire.o geohash.o geohash_helper.o childinfo.o allocator_defrag.o defrag.o siphash.o rax.o t_stream.o listpack.o localtime.o lolwut.o lolwut5.o lolwut6.o lolwut9.o acl.o tracking.o socket.o tls.o sha256.o timeout.o setcpuaffinity.o monotonic.o mt19937-64.o resp_parser.o call_reply.o script.o functions.o commands.o strl.o connection.o unix.o logreqres.o rdma.o scripting_engine.o entry.o vset.o lua/script_lua.o lua/function_lua.o lua/engine_lua.o lua/debug_lua.o +ENGINE_SERVER_OBJ=threads_mngr.o adlist.o vector.o quicklist.o ae.o anet.o dict.o hashtable.o kvstore.o server.o sds.o zmalloc.o lzf_c.o lzf_d.o pqsort.o zipmap.o sha1.o ziplist.o release.o memory_prefetch.o io_threads.o networking.o util.o object.o db.o replication.o rdb.o t_string.o t_list.o t_set.o t_zset.o t_hash.o config.o aof.o pubsub.o multi.o debug.o sort.o intset.o syncio.o cluster.o cluster_legacy.o cluster_slot_stats.o crc16.o cluster_migrateslots.o endianconv.o commandlog.o eval.o bio.o rio.o rand.o memtest.o syscheck.o crcspeed.o crccombine.o crc64.o bitops.o sentinel.o notify.o setproctitle.o blocked.o hyperloglog.o latency.o sparkline.o valkey-check-rdb.o valkey-check-aof.o geo.o lazyfree.o module.o evict.o expire.o geohash.o geohash_helper.o childinfo.o allocator_defrag.o defrag.o siphash.o rax.o t_stream.o listpack.o localtime.o lolwut.o lolwut5.o lolwut6.o lolwut9.o acl.o tracking.o socket.o tls.o sha256.o timeout.o setcpuaffinity.o monotonic.o mt19937-64.o resp_parser.o call_reply.o script.o functions.o commands.o strl.o connection.o unix.o logreqres.o rdma.o scripting_engine.o entry.o vset.o ENGINE_SERVER_OBJ+=$(ENGINE_TRACE_OBJ) ENGINE_CLI_NAME=$(ENGINE_NAME)-cli$(PROG_SUFFIX) ENGINE_CLI_OBJ=anet.o adlist.o dict.o valkey-cli.o zmalloc.o release.o ae.o serverassert.o crcspeed.o crccombine.o crc64.o siphash.o crc16.o monotonic.o cli_common.o mt19937-64.o strl.o cli_commands.o sds.o util.o sha256.o @@ -448,7 +467,7 @@ ifeq ($(USE_FAST_FLOAT),yes) FINAL_LIBS += $(FAST_FLOAT_STRTOD_OBJECT) endif -all: $(SERVER_NAME) $(ENGINE_SENTINEL_NAME) $(ENGINE_CLI_NAME) $(ENGINE_BENCHMARK_NAME) $(ENGINE_CHECK_RDB_NAME) $(ENGINE_CHECK_AOF_NAME) $(TLS_MODULE) $(RDMA_MODULE) +all: $(SERVER_NAME) $(ENGINE_SENTINEL_NAME) $(ENGINE_CLI_NAME) $(ENGINE_BENCHMARK_NAME) $(ENGINE_CHECK_RDB_NAME) $(ENGINE_CHECK_AOF_NAME) $(TLS_MODULE) $(RDMA_MODULE) $(LUA_MODULE) @echo "" @echo "Hint: It's a good idea to run 'make test' ;)" @echo "" @@ -473,6 +492,7 @@ persist-settings: distclean echo BUILD_TLS=$(BUILD_TLS) >> .make-settings echo BUILD_RDMA=$(BUILD_RDMA) >> .make-settings echo USE_SYSTEMD=$(USE_SYSTEMD) >> .make-settings + echo BUILD_LUA=$(BUILD_LUA) >> .make-settings echo CFLAGS=$(CFLAGS) >> .make-settings echo LDFLAGS=$(LDFLAGS) >> .make-settings echo SERVER_CFLAGS=$(SERVER_CFLAGS) >> .make-settings @@ -498,7 +518,7 @@ endif # valkey-server $(SERVER_NAME): $(ENGINE_SERVER_OBJ) - $(SERVER_LD) -o $@ $^ ../deps/libvalkey/lib/libvalkey.a ../deps/lua/src/liblua.a ../deps/hdr_histogram/libhdrhistogram.a ../deps/fpconv/libfpconv.a $(FINAL_LIBS) + $(SERVER_LD) -o $@ $^ ../deps/libvalkey/lib/libvalkey.a ../deps/hdr_histogram/libhdrhistogram.a ../deps/fpconv/libfpconv.a $(FINAL_LIBS) # Valkey static library, used to compile against for unit testing $(ENGINE_LIB_NAME): $(ENGINE_SERVER_OBJ) @@ -506,7 +526,7 @@ $(ENGINE_LIB_NAME): $(ENGINE_SERVER_OBJ) # valkey-unit-tests $(ENGINE_UNIT_TESTS): $(ENGINE_TEST_OBJ) $(ENGINE_LIB_NAME) - $(SERVER_LD) -o $@ $^ ../deps/libvalkey/lib/libvalkey.a ../deps/lua/src/liblua.a ../deps/hdr_histogram/libhdrhistogram.a ../deps/fpconv/libfpconv.a $(FINAL_LIBS) + $(SERVER_LD) -o $@ $^ ../deps/libvalkey/lib/libvalkey.a ../deps/hdr_histogram/libhdrhistogram.a ../deps/fpconv/libfpconv.a $(FINAL_LIBS) # valkey-sentinel $(ENGINE_SENTINEL_NAME): $(SERVER_NAME) @@ -528,6 +548,10 @@ $(TLS_MODULE_NAME): $(SERVER_NAME) $(RDMA_MODULE_NAME): $(SERVER_NAME) $(QUIET_CC)$(CC) -o $@ rdma.c -shared -fPIC $(RDMA_MODULE_CFLAGS) +# engine_lua.so +$(LUA_MODULE_NAME): $(SERVER_NAME) + cd modules/lua && $(MAKE) OPTIMIZATION="$(OPTIMIZATION)" + # valkey-cli $(ENGINE_CLI_NAME): $(ENGINE_CLI_OBJ) $(SERVER_LD) -o $@ $^ ../deps/libvalkey/lib/libvalkey.a ../deps/linenoise/linenoise.o ../deps/fpconv/libfpconv.a $(FINAL_LIBS) $(TLS_CLIENT_LIBS) $(RDMA_CLIENT_LIBS) @@ -545,9 +569,6 @@ DEP = $(ENGINE_SERVER_OBJ:%.o=%.d) $(ENGINE_CLI_OBJ:%.o=%.d) $(ENGINE_BENCHMARK_ %.o: %.c .make-prerequisites $(SERVER_CC) -MMD -o $@ -c $< -lua/%.o: lua/%.c .make-prerequisites - $(SERVER_CC) -MMD -o $@ -c $< - trace/%.o: trace/%.c .make-prerequisites $(SERVER_CC) -Itrace -MMD -o $@ -c $< @@ -574,8 +595,9 @@ endif commands.c: $(COMMANDS_DEF_FILENAME).def clean: - rm -rf $(SERVER_NAME) $(ENGINE_SENTINEL_NAME) $(ENGINE_CLI_NAME) $(ENGINE_BENCHMARK_NAME) $(ENGINE_CHECK_RDB_NAME) $(ENGINE_CHECK_AOF_NAME) $(ENGINE_UNIT_TESTS) $(ENGINE_LIB_NAME) unit/*.o unit/*.d lua/*.o lua/*.d trace/*.o trace/*.d *.o *.gcda *.gcno *.gcov valkey.info lcov-html Makefile.dep *.so + rm -rf $(SERVER_NAME) $(ENGINE_SENTINEL_NAME) $(ENGINE_CLI_NAME) $(ENGINE_BENCHMARK_NAME) $(ENGINE_CHECK_RDB_NAME) $(ENGINE_CHECK_AOF_NAME) $(ENGINE_UNIT_TESTS) $(ENGINE_LIB_NAME) unit/*.o unit/*.d trace/*.o trace/*.d *.o *.gcda *.gcno *.gcov valkey.info lcov-html Makefile.dep *.so rm -f $(DEP) + -(cd modules/lua && $(MAKE) clean) .PHONY: clean @@ -634,7 +656,7 @@ valgrind: helgrind: $(MAKE) OPTIMIZATION="-O0" MALLOC="libc" CFLAGS="-D__ATOMIC_VAR_FORCE_SYNC_MACROS" SERVER_CFLAGS="-I/usr/local/include" SERVER_LDFLAGS="-L/usr/local/lib" -install: all +install: all $(LUA_MODULE_INSTALL) @mkdir -p $(INSTALL_BIN) $(call MAKE_INSTALL,$(SERVER_NAME),$(INSTALL_BIN)) $(call MAKE_INSTALL,$(ENGINE_BENCHMARK_NAME),$(INSTALL_BIN)) @@ -649,6 +671,10 @@ install: all $(call MAYBE_INSTALL_REDIS_SYMLINK,$(ENGINE_CHECK_AOF_NAME),$(INSTALL_BIN)) $(call MAYBE_INSTALL_REDIS_SYMLINK,$(ENGINE_SENTINEL_NAME),$(INSTALL_BIN)) +install-lua-module: $(LUA_MODULE) + @mkdir -p $(INSTALL_LIB) + $(call MAKE_INSTALL,$(LUA_MODULE),$(INSTALL_LIB)) + uninstall: @rm -f $(INSTALL_BIN)/{$(SERVER_NAME),$(ENGINE_BENCHMARK_NAME),$(ENGINE_CLI_NAME),$(ENGINE_CHECK_RDB_NAME),$(ENGINE_CHECK_AOF_NAME),$(ENGINE_SENTINEL_NAME)} $(call MAYBE_UNINSTALL_REDIS_SYMLINK,$(INSTALL_BIN),$(SERVER_NAME)) diff --git a/src/config.c b/src/config.c index 58b3324fa4..0c2ae6198c 100644 --- a/src/config.c +++ b/src/config.c @@ -29,6 +29,7 @@ */ #include "io_threads.h" +#include "sds.h" #include "server.h" #include "cluster.h" #include "connection.h" diff --git a/src/eval.c b/src/eval.c index bbfe080f94..c85273bb07 100644 --- a/src/eval.c +++ b/src/eval.c @@ -240,7 +240,7 @@ int evalExtractShebangFlags(sds body, } if (out_engine) { - uint32_t engine_name_len = sdslen(parts[0]) - 2; + size_t engine_name_len = sdslen(parts[0]) - 2; *out_engine = zcalloc(engine_name_len + 1); valkey_strlcpy(*out_engine, parts[0] + 2, engine_name_len + 1); } @@ -442,7 +442,7 @@ static int evalRegisterNewScript(client *c, robj *body, char **sha) { } es->body = body; int retval = dictAdd(evalCtx.scripts, _sha, es); - serverAssertWithInfo(c ? c : scriptingEngineGetClient(engine), NULL, retval == DICT_OK); + serverAssert(retval == DICT_OK); evalCtx.scripts_mem += sdsAllocSize(_sha) + getStringObjectSdsUsedMemory(body); incrRefCount(body); zfree(functions); diff --git a/src/lua/engine_lua.c b/src/lua/engine_lua.c deleted file mode 100644 index 5951ceaf38..0000000000 --- a/src/lua/engine_lua.c +++ /dev/null @@ -1,435 +0,0 @@ -/* - * Copyright (c) Valkey Contributors - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -#include "engine_lua.h" -#include "function_lua.h" -#include "script_lua.h" -#include "debug_lua.h" - -#define LUA_ENGINE_NAME "LUA" -#define REGISTRY_ERROR_HANDLER_NAME "__ERROR_HANDLER__" - -typedef struct luaEngineCtx { - lua_State *eval_lua; /* The Lua interpreter for EVAL commands. We use just one for all EVAL calls */ - lua_State *function_lua; /* The Lua interpreter for FCALL commands. We use just one for all FCALL calls */ -} luaEngineCtx; - -/* Adds server.debug() function used by lua debugger - * - * Log a string message into the output console. - * Can take multiple arguments that will be separated by commas. - * Nothing is returned to the caller. */ -static int luaServerDebugCommand(lua_State *lua) { - if (!ldbIsActive()) return 0; - int argc = lua_gettop(lua); - sds log = sdscatprintf(sdsempty(), " line %d: ", ldbGetCurrentLine()); - while (argc--) { - log = ldbCatStackValue(log, lua, -1 - argc); - if (argc != 0) log = sdscatlen(log, ", ", 2); - } - ldbLog(log); - return 0; -} - -/* Adds server.breakpoint() function used by lua debugger. - * - * Allows to stop execution during a debugging session from within - * the Lua code implementation, like if a breakpoint was set in the code - * immediately after the function. */ -static int luaServerBreakpointCommand(lua_State *lua) { - if (ldbIsActive()) { - ldbSetBreakpointOnNextLine(1); - lua_pushboolean(lua, 1); - } else { - lua_pushboolean(lua, 0); - } - return 1; -} - - -/* Adds server.replicate_commands() - * - * DEPRECATED: Now do nothing and always return true. - * Turn on single commands replication if the script never called - * a write command so far, and returns true. Otherwise if the script - * already started to write, returns false and stick to whole scripts - * replication, which is our default. */ -int luaServerReplicateCommandsCommand(lua_State *lua) { - lua_pushboolean(lua, 1); - return 1; -} - -static void luaStateInstallErrorHandler(lua_State *lua) { - /* Add a helper function we use for pcall error reporting. - * Note that when the error is in the C function we want to report the - * information about the caller, that's what makes sense from the point - * of view of the user debugging a script. */ - lua_pushstring(lua, REGISTRY_ERROR_HANDLER_NAME); - char *errh_func = "local dbg = debug\n" - "debug = nil\n" - "local error_handler = function (err)\n" - " local i = dbg.getinfo(2,'nSl')\n" - " if i and i.what == 'C' then\n" - " i = dbg.getinfo(3,'nSl')\n" - " end\n" - " if type(err) ~= 'table' then\n" - " err = {err='ERR ' .. tostring(err)}" - " end" - " if i then\n" - " err['source'] = i.source\n" - " err['line'] = i.currentline\n" - " end" - " return err\n" - "end\n" - "return error_handler"; - luaL_loadbuffer(lua, errh_func, strlen(errh_func), "@err_handler_def"); - lua_pcall(lua, 0, 1, 0); - lua_settable(lua, LUA_REGISTRYINDEX); -} - -static void luaStateLockGlobalTable(lua_State *lua) { - /* Lock the global table from any changes */ - lua_pushvalue(lua, LUA_GLOBALSINDEX); - luaSetErrorMetatable(lua); - /* Recursively lock all tables that can be reached from the global table */ - luaSetTableProtectionRecursively(lua); - lua_pop(lua, 1); - /* Set metatables of basic types (string, number, nil etc.) readonly. */ - luaSetTableProtectionForBasicTypes(lua); -} - - -static void initializeEvalLuaState(lua_State *lua) { - /* register debug commands. we only need to add it under 'server' as 'redis' - * is effectively aliased to 'server' table at this point. */ - lua_getglobal(lua, "server"); - - /* server.breakpoint */ - lua_pushstring(lua, "breakpoint"); - lua_pushcfunction(lua, luaServerBreakpointCommand); - lua_settable(lua, -3); - - /* server.debug */ - lua_pushstring(lua, "debug"); - lua_pushcfunction(lua, luaServerDebugCommand); - lua_settable(lua, -3); - - /* server.replicate_commands */ - lua_pushstring(lua, "replicate_commands"); - lua_pushcfunction(lua, luaServerReplicateCommandsCommand); - lua_settable(lua, -3); - - lua_setglobal(lua, "server"); - - /* Duplicate the function with __server__err__hanler and - * __redis__err_handler name for backwards compatibility. */ - lua_pushstring(lua, REGISTRY_ERROR_HANDLER_NAME); - lua_gettable(lua, LUA_REGISTRYINDEX); - lua_setglobal(lua, "__server__err__handler"); - lua_getglobal(lua, "__server__err__handler"); - lua_setglobal(lua, "__redis__err__handler"); -} - -static void initializeLuaState(luaEngineCtx *lua_engine_ctx, - subsystemType type) { - lua_State *lua = lua_open(); - - if (type == VMSE_EVAL) { - lua_engine_ctx->eval_lua = lua; - } else { - serverAssert(type == VMSE_FUNCTION); - lua_engine_ctx->function_lua = lua; - } - - luaRegisterServerAPI(lua); - luaStateInstallErrorHandler(lua); - - if (type == VMSE_EVAL) { - initializeEvalLuaState(lua); - luaStateLockGlobalTable(lua); - } else { - luaStateLockGlobalTable(lua); - luaFunctionInitializeLuaState(lua); - } -} - -static struct luaEngineCtx *createEngineContext(void) { - luaEngineCtx *lua_engine_ctx = zmalloc(sizeof(*lua_engine_ctx)); - - initializeLuaState(lua_engine_ctx, VMSE_EVAL); - initializeLuaState(lua_engine_ctx, VMSE_FUNCTION); - - return lua_engine_ctx; -} - -static engineMemoryInfo luaEngineGetMemoryInfo(ValkeyModuleCtx *module_ctx, - engineCtx *engine_ctx, - subsystemType type) { - /* The lua engine is implemented in the core, and not in a Valkey Module */ - serverAssert(module_ctx == NULL); - - luaEngineCtx *lua_engine_ctx = engine_ctx; - engineMemoryInfo mem_info = {0}; - - if (type == VMSE_EVAL || type == VMSE_ALL) { - mem_info.used_memory += luaMemory(lua_engine_ctx->eval_lua); - } - if (type == VMSE_FUNCTION || type == VMSE_ALL) { - mem_info.used_memory += luaMemory(lua_engine_ctx->function_lua); - } - - mem_info.engine_memory_overhead = zmalloc_size(engine_ctx); - - return mem_info; -} - -static compiledFunction **luaEngineCompileCode(ValkeyModuleCtx *module_ctx, - engineCtx *engine_ctx, - subsystemType type, - const char *code, - size_t code_len, - size_t timeout, - size_t *out_num_compiled_functions, - robj **err) { - /* The lua engine is implemented in the core, and not in a Valkey Module */ - serverAssert(module_ctx == NULL); - - luaEngineCtx *lua_engine_ctx = (luaEngineCtx *)engine_ctx; - compiledFunction **functions = NULL; - - if (type == VMSE_EVAL) { - lua_State *lua = lua_engine_ctx->eval_lua; - - if (luaL_loadbuffer(lua, code, code_len, "@user_script")) { - sds error = sdscatfmt(sdsempty(), "Error compiling script (new function): %s", lua_tostring(lua, -1)); - *err = createObject(OBJ_STRING, error); - lua_pop(lua, 1); - return functions; - } - - serverAssert(lua_isfunction(lua, -1)); - int function_ref = luaL_ref(lua, LUA_REGISTRYINDEX); - - luaFunction *script = zcalloc(sizeof(luaFunction)); - *script = (luaFunction){ - .lua = lua, - .function_ref = function_ref, - }; - - compiledFunction *func = zcalloc(sizeof(*func)); - *func = (compiledFunction){ - .name = NULL, - .function = script, - .desc = NULL, - .f_flags = 0}; - - *out_num_compiled_functions = 1; - functions = zcalloc(sizeof(compiledFunction *)); - *functions = func; - } else { - functions = luaFunctionLibraryCreate(lua_engine_ctx->function_lua, - code, - timeout, - out_num_compiled_functions, - err); - } - - return functions; -} - -static void luaEngineFunctionCall(ValkeyModuleCtx *module_ctx, - engineCtx *engine_ctx, - serverRuntimeCtx *server_ctx, - compiledFunction *compiled_function, - subsystemType type, - robj **keys, - size_t nkeys, - robj **args, - size_t nargs) { - /* The lua engine is implemented in the core, and not in a Valkey Module */ - serverAssert(module_ctx == NULL); - - luaEngineCtx *lua_engine_ctx = (luaEngineCtx *)engine_ctx; - lua_State *lua = type == VMSE_EVAL ? lua_engine_ctx->eval_lua : lua_engine_ctx->function_lua; - luaFunction *script = compiled_function->function; - int lua_function_ref = script->function_ref; - - /* Push the pcall error handler function on the stack. */ - lua_pushstring(lua, REGISTRY_ERROR_HANDLER_NAME); - lua_gettable(lua, LUA_REGISTRYINDEX); - - lua_rawgeti(lua, LUA_REGISTRYINDEX, lua_function_ref); - serverAssert(!lua_isnil(lua, -1)); - - luaCallFunction(server_ctx, - lua, - keys, - nkeys, - args, - nargs, - type == VMSE_EVAL ? ldbIsActive() : 0); - - lua_pop(lua, 1); /* Remove the error handler. */ -} - -static void resetLuaContext(void *context) { - lua_State *lua = context; - lua_gc(lua, LUA_GCCOLLECT, 0); - lua_close(lua); - -#if !defined(USE_LIBC) - /* The lua interpreter may hold a lot of memory internally, and lua is - * using libc. libc may take a bit longer to return the memory to the OS, - * so after lua_close, we call malloc_trim try to purge it earlier. - * - * We do that only when the server itself does not use libc. When Lua and the server - * use different allocators, one won't use the fragmentation holes of the - * other, and released memory can take a long time until it is returned to - * the OS. */ - zlibc_trim(); -#endif -} - -static callableLazyEnvReset *luaEngineResetEvalEnv(ValkeyModuleCtx *module_ctx, - engineCtx *engine_ctx, - subsystemType type, - int async) { - /* The lua engine is implemented in the core, and not in a Valkey Module */ - serverAssert(module_ctx == NULL); - - luaEngineCtx *lua_engine_ctx = (luaEngineCtx *)engine_ctx; - serverAssert(type == VMSE_EVAL || type == VMSE_FUNCTION); - lua_State *lua = type == VMSE_EVAL ? lua_engine_ctx->eval_lua : lua_engine_ctx->function_lua; - serverAssert(lua); - callableLazyEnvReset *callback = NULL; - - if (async) { - callback = zcalloc(sizeof(*callback)); - *callback = (callableLazyEnvReset){ - .context = lua, - .engineLazyEnvResetCallback = resetLuaContext, - }; - } else { - resetLuaContext(lua); - } - - initializeLuaState(lua_engine_ctx, type); - - return callback; -} - -static size_t luaEngineFunctionMemoryOverhead(ValkeyModuleCtx *module_ctx, - compiledFunction *compiled_function) { - /* The lua engine is implemented in the core, and not in a Valkey Module */ - serverAssert(module_ctx == NULL); - - return zmalloc_size(compiled_function->function) + - (compiled_function->name ? zmalloc_size(compiled_function->name) : 0) + - (compiled_function->desc ? zmalloc_size(compiled_function->desc) : 0) + - zmalloc_size(compiled_function); -} - -static void luaEngineFreeFunction(ValkeyModuleCtx *module_ctx, - engineCtx *engine_ctx, - subsystemType type, - compiledFunction *compiled_function) { - /* The lua engine is implemented in the core, and not in a Valkey Module */ - serverAssert(module_ctx == NULL); - serverAssert(type == VMSE_EVAL || type == VMSE_FUNCTION); - - luaEngineCtx *lua_engine_ctx = engine_ctx; - lua_State *lua = type == VMSE_EVAL ? lua_engine_ctx->eval_lua : lua_engine_ctx->function_lua; - serverAssert(lua); - - luaFunction *script = (luaFunction *)compiled_function->function; - if (lua == script->lua) { - /* The lua context is still the same, which means that we're not - * resetting the whole eval context, and therefore, we need to - * delete the function from the lua context. - */ - lua_unref(lua, script->function_ref); - } - zfree(script); - - if (compiled_function->name) { - decrRefCount(compiled_function->name); - } - if (compiled_function->desc) { - decrRefCount(compiled_function->desc); - } - zfree(compiled_function); -} - -static debuggerEnableRet luaEngineDebuggerEnable(ValkeyModuleCtx *module_ctx, - engineCtx *engine_ctx, - subsystemType type, - const debuggerCommand **commands, - size_t *commands_len) { - UNUSED(module_ctx); - - if (type != VMSE_EVAL) { - return VMSE_DEBUG_NOT_SUPPORTED; - } - - ldbEnable(); - - luaEngineCtx *lua_engine_ctx = engine_ctx; - ldbGenerateDebuggerCommandsArray(lua_engine_ctx->eval_lua, - commands, - commands_len); - - return VMSE_DEBUG_ENABLED; -} - -static void luaEngineDebuggerDisable(ValkeyModuleCtx *module_ctx, - engineCtx *engine_ctx, - subsystemType type) { - UNUSED(module_ctx); - UNUSED(engine_ctx); - UNUSED(type); - ldbDisable(); -} - -static void luaEngineDebuggerStart(ValkeyModuleCtx *module_ctx, - engineCtx *engine_ctx, - subsystemType type, - robj *source) { - UNUSED(module_ctx); - UNUSED(engine_ctx); - UNUSED(type); - ldbStart(source); -} - -static void luaEngineDebuggerEnd(ValkeyModuleCtx *module_ctx, - engineCtx *engine_ctx, - subsystemType type) { - UNUSED(module_ctx); - UNUSED(engine_ctx); - UNUSED(type); - ldbEnd(); -} - -int luaEngineInitEngine(void) { - ldbInit(); - - engineMethods methods = { - .version = VALKEYMODULE_SCRIPTING_ENGINE_ABI_VERSION, - .compile_code = luaEngineCompileCode, - .free_function = luaEngineFreeFunction, - .call_function = luaEngineFunctionCall, - .get_function_memory_overhead = luaEngineFunctionMemoryOverhead, - .reset_env = luaEngineResetEvalEnv, - .get_memory_info = luaEngineGetMemoryInfo, - .debugger_enable = luaEngineDebuggerEnable, - .debugger_disable = luaEngineDebuggerDisable, - .debugger_start = luaEngineDebuggerStart, - .debugger_end = luaEngineDebuggerEnd, - }; - - return scriptingEngineManagerRegister(LUA_ENGINE_NAME, - NULL, - createEngineContext(), - &methods); -} diff --git a/src/lua/engine_lua.h b/src/lua/engine_lua.h deleted file mode 100644 index db4ad18c08..0000000000 --- a/src/lua/engine_lua.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef _ENGINE_LUA_ -#define _ENGINE_LUA_ - -#include "../scripting_engine.h" -#include - -typedef struct luaFunction { - lua_State *lua; /* Pointer to the lua context where this function was created. Only used in EVAL context. */ - int function_ref; /* Special ID that allows getting the Lua function object from the Lua registry */ -} luaFunction; - -int luaEngineInitEngine(void); - -#endif /* _ENGINE_LUA_ */ diff --git a/src/lua/function_lua.h b/src/lua/function_lua.h deleted file mode 100644 index 6b45cef1df..0000000000 --- a/src/lua/function_lua.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef _FUNCTION_LUA_H_ -#define _FUNCTION_LUA_H_ - -#include "engine_lua.h" - -void luaFunctionInitializeLuaState(lua_State *lua); - -compiledFunction **luaFunctionLibraryCreate(lua_State *lua, - const char *code, - size_t timeout, - size_t *out_num_compiled_functions, - robj **err); - -void luaFunctionFreeFunction(lua_State *lua, void *function); - -#endif /* _FUNCTION_LUA_H_ */ diff --git a/src/module.c b/src/module.c index ae5f622780..e5f0692efa 100644 --- a/src/module.c +++ b/src/module.c @@ -899,6 +899,8 @@ void moduleFreeContext(ValkeyModuleCtx *ctx) { moduleReleaseTempClient(ctx->client); else if (ctx->flags & VALKEYMODULE_CTX_NEW_CLIENT) freeClient(ctx->client); + else if (ctx->flags & VALKEYMODULE_CTX_SCRIPT_EXECUTION) + ctx->client = NULL; /* Do not free the client, it was assigned manually. */ } static CallReply *moduleParseReply(client *c, ValkeyModuleCtx *ctx) { @@ -998,9 +1000,29 @@ void moduleCreateContext(ValkeyModuleCtx *out_ctx, ValkeyModule *module, int ctx */ void moduleScriptingEngineInitContext(ValkeyModuleCtx *out_ctx, ValkeyModule *module, + int add_script_execution_flag, + int add_thread_safe_flag, client *client) { - moduleCreateContext(out_ctx, module, VALKEYMODULE_CTX_SCRIPT_EXECUTION); - out_ctx->client = client; + /* The VALKEYMODULE_CTX_SCRIPT_EXECUTION requires a non-NULL client */ + serverAssert(!add_script_execution_flag || client != NULL); + + /* For non-script execution contexts, and non-asynchronous contexts, allocate + * a temporary client so the scripting engine can call server commands in + * its callbacks. */ + int ctx_flags = VALKEYMODULE_CTX_TEMP_CLIENT; + + if (add_script_execution_flag) { + ctx_flags = VALKEYMODULE_CTX_SCRIPT_EXECUTION; + } + if (add_thread_safe_flag) { + ctx_flags = VALKEYMODULE_CTX_THREAD_SAFE; + } + + moduleCreateContext(out_ctx, module, ctx_flags); + + if (add_script_execution_flag) { + out_ctx->client = client; + } } /* This command binds the normal command invocation with commands @@ -6454,11 +6476,17 @@ ValkeyModuleCallReply *VM_Call(ValkeyModuleCtx *ctx, const char *cmdname, const if (flags & VALKEYMODULE_ARGV_RESP_3) { c->resp = 3; } else if (flags & VALKEYMODULE_ARGV_RESP_AUTO) { + serverAssert(ctx->client != NULL); /* Auto mode means to take the same protocol as the ctx client. */ c->resp = ctx->client->resp; } if (ctx->module) ctx->module->in_call++; + if (flags & VALKEYMODULE_ARGV_SCRIPT_MODE && is_running_script) { + c->flag.module = 0; + c->flag.script = 1; + } + user *user = NULL; if (flags & VALKEYMODULE_ARGV_RUN_AS_USER) { user = ctx->user ? ctx->user->user : ctx->client->user; @@ -6502,11 +6530,6 @@ ValkeyModuleCallReply *VM_Call(ValkeyModuleCtx *ctx, const char *cmdname, const cmd_flags = getCommandFlags(c); if (flags & VALKEYMODULE_ARGV_SCRIPT_MODE) { - if (is_running_script) { - c->flag.module = 0; - c->flag.script = 1; - } - /* In script mode, commands with CMD_NOSCRIPT flag are normally forbidden. * However, we allow them if both conditions are met: * 1. We're running in the context of a scripting engine running a script @@ -6632,7 +6655,7 @@ ValkeyModuleCallReply *VM_Call(ValkeyModuleCtx *ctx, const char *cmdname, const * CLIENT PAUSE WRITE. */ if (is_running_script && scriptIsReadOnly() && (cmd_flags & (CMD_WRITE | CMD_MAY_REPLICATE))) { errno = ENOSPC; - reply_error_msg = sdsnew("Write commands are not allowed from read-only scripts"); + reply_error_msg = sdsnew("Write commands are not allowed from read-only scripts."); goto cleanup; } @@ -12744,16 +12767,8 @@ int moduleLoad(const char *path, void **module_argv, int module_argc, int is_loa return C_OK; } -/* Unload the module registered with the specified name. On success - * C_OK is returned, otherwise C_ERR is returned and errmsg is set - * with an appropriate message. */ -int moduleUnload(sds name, const char **errmsg) { - struct ValkeyModule *module = dictFetchValue(modules, name); - - if (module == NULL) { - *errmsg = "no such module with that name"; - return C_ERR; - } else if (listLength(module->types)) { +static int moduleUnloadInternal(struct ValkeyModule *module, const char **errmsg) { + if (listLength(module->types)) { *errmsg = "the module exports one or more module-side data " "types, can't unload"; return C_ERR; @@ -12778,7 +12793,7 @@ int moduleUnload(sds name, const char **errmsg) { onunload = (int (*)(void *))(unsigned long)dlsym(module->handle, onUnloadNames[i]); if (onunload) { if (i != 0) { - serverLog(LL_NOTICE, "Legacy Redis Module %s found", name); + serverLog(LL_NOTICE, "Legacy Redis Module %s found", module->name); } break; } @@ -12791,7 +12806,7 @@ int moduleUnload(sds name, const char **errmsg) { moduleFreeContext(&ctx); if (unload_status == VALKEYMODULE_ERR) { - serverLog(LL_WARNING, "Module %s OnUnload failed. Unload canceled.", name); + serverLog(LL_WARNING, "Module %s OnUnload failed. Unload canceled.", module->name); errno = ECANCELED; return C_ERR; } @@ -12820,6 +12835,49 @@ int moduleUnload(sds name, const char **errmsg) { return C_OK; } +/* Unload the module registered with the specified name. On success + * C_OK is returned, otherwise C_ERR is returned and errmsg is set + * with an appropriate message. */ +int moduleUnload(sds name, const char **errmsg) { + struct ValkeyModule *module = dictFetchValue(modules, name); + + if (module == NULL) { + *errmsg = "no such module with that name"; + return C_ERR; + } + + return moduleUnloadInternal(module, errmsg); +} + +/* Unload all loaded modules from the server. + * + * This function iterates through all modules registered in the server's + * module dictionary and attempts to unload each one by calling + * moduleUnloadInternal(). If a module fails to unload (e.g., due to + * having active data types, blocked clients, or being used by other modules), + * the function logs a warning message but continues attempting to unload + * the remaining modules. + * + * This function is currently only called during server shutdown to ensure + * proper cleanup of all module resources. It attempts to unload all modules + * on a best-effort basis, and therefore the shutdown process is not interrupted + * by module unload failures. + */ +void moduleUnloadAllModules(void) { + dictIterator *di = dictGetSafeIterator(modules); + dictEntry *de; + + while ((de = dictNext(di)) != NULL) { + struct ValkeyModule *module = dictGetVal(de); + + const char *errmsg = NULL; + if (moduleUnloadInternal(module, &errmsg) == C_ERR) { + serverLog(LL_WARNING, "Failed to unload module %s: %s", module->name, errmsg); + } + } + dictReleaseIterator(di); +} + void modulePipeReadable(aeEventLoop *el, int fd, void *privdata, int mask) { UNUSED(el); UNUSED(fd); diff --git a/src/module.h b/src/module.h index f6c266b592..5371fee3c9 100644 --- a/src/module.h +++ b/src/module.h @@ -175,6 +175,8 @@ sds moduleLoadQueueEntryToLoadmoduleOptionStr(ValkeyModule *module, ValkeyModuleCtx *moduleAllocateContext(void); void moduleScriptingEngineInitContext(ValkeyModuleCtx *out_ctx, ValkeyModule *module, + int add_script_execution_flag, + int add_thread_safe_flag, client *client); void moduleFreeContext(ValkeyModuleCtx *ctx); void moduleInitModulesSystem(void); @@ -182,6 +184,7 @@ void moduleInitModulesSystemLast(void); void modulesCron(void); int moduleLoad(const char *path, void **argv, int argc, int is_loadex); int moduleUnload(sds name, const char **errmsg); +void moduleUnloadAllModules(void); void moduleLoadFromQueue(void); int moduleGetCommandKeysViaAPI(struct serverCommand *cmd, robj **argv, int argc, getKeysResult *result); int moduleGetCommandChannelsViaAPI(struct serverCommand *cmd, robj **argv, int argc, getKeysResult *result); diff --git a/src/modules/lua/CMakeLists.txt b/src/modules/lua/CMakeLists.txt new file mode 100644 index 0000000000..05f343ad2a --- /dev/null +++ b/src/modules/lua/CMakeLists.txt @@ -0,0 +1,24 @@ +project(valkeylua) + +if (VALKEY_DEBUG_BUILD) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -W -Wall -fno-common -g -ggdb -std=c99 -O2 -D_GNU_SOURCE") +else () + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -W -Wall -fno-common -O3 -std=c99 -D_GNU_SOURCE") +endif() + +set(LUA_ENGINE_SRCS + engine_lua.c + script_lua.c + function_lua.c + debug_lua.c + list.c) + +add_library(valkeylua SHARED "${LUA_ENGINE_SRCS}") + +add_dependencies(valkeylua lualib) +target_link_libraries(valkeylua PRIVATE lualib) +target_include_directories(valkeylua PRIVATE ../../../deps/lua/src) + +install(TARGETS valkeylua + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} +) diff --git a/src/modules/lua/Makefile b/src/modules/lua/Makefile new file mode 100644 index 0000000000..08b5954f0e --- /dev/null +++ b/src/modules/lua/Makefile @@ -0,0 +1,45 @@ +uname_S := $(shell sh -c 'uname -s 2>/dev/null || echo not') + +DEPS_DIR=../../../deps + +ifeq ($(uname_S),Linux) + SHOBJ_CFLAGS= -I. -I$(DEPS_DIR)/lua/src -I$(DEPS_DIR)/fpconv -fPIC -W -Wall -fno-common $(OPTIMIZATION) -std=c99 -D_GNU_SOURCE $(CFLAGS) + SHOBJ_LDFLAGS= -shared $(LDFLAGS) +else + SHOBJ_CFLAGS= -I. -I$(DEPS_DIR)/lua/src -I$(DEPS_DIR)/fpconv -fPIC -W -Wall -dynamic -fno-common $(OPTIMIZATION) -std=c99 -D_GNU_SOURCE $(CFLAGS) + SHOBJ_LDFLAGS= -bundle -undefined dynamic_lookup $(LDFLAGS) +endif + +LIBS= $(DEPS_DIR)/lua/src/liblua.a $(DEPS_DIR)/fpconv/libfpconv.a +SRCS= $(wildcard *.c) +OBJS= $(SRCS:.c=.o) sha1.o rand.o + +# OS X 11.x doesn't have /usr/lib/libSystem.dylib and needs an explicit setting. +ifeq ($(uname_S),Darwin) +ifeq ("$(wildcard /usr/lib/libSystem.dylib)","") + SHOBJ_LDFLAGS+= -L /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib -lsystem +endif +endif + +all: libvalkeylua.so + +libvalkeylua.so: $(OBJS) $(LIBS) + $(CC) -o $@ $(SHOBJ_LDFLAGS) $^ + +sha1.o: ../../sha1.c + $(CC) $(SHOBJ_CFLAGS) -c $< -o $@ + +rand.o: ../../rand.c + $(CC) $(SHOBJ_CFLAGS) -c $< -o $@ + +%.o: %.c + $(CC) $(SHOBJ_CFLAGS) -c $< -o $@ + +$(DEPS_DIR)/lua/src/liblua.a: + cd $(DEPS_DIR) && $(MAKE) lua + +$(DEPS_DIR)/fpconv/libfpconv.a: + cd $(DEPS_DIR) && $(MAKE) fpconv + +clean: + rm -f *.so $(OBJS) diff --git a/src/lua/debug_lua.c b/src/modules/lua/debug_lua.c similarity index 60% rename from src/lua/debug_lua.c rename to src/modules/lua/debug_lua.c index 9b295fb71b..0c6f56ab3d 100644 --- a/src/lua/debug_lua.c +++ b/src/modules/lua/debug_lua.c @@ -7,8 +7,8 @@ #include "debug_lua.h" #include "script_lua.h" -#include "../server.h" - +#include +#include #include #include #include @@ -25,7 +25,7 @@ struct ldbState { int bpcount; /* Number of valid entries inside bp. */ int step; /* Stop at next line regardless of breakpoints. */ int luabp; /* Stop at next line because server.breakpoint() was called. */ - sds *src; /* Lua script source code split by line. */ + char **src; /* Lua script source code split by line. */ int lines; /* Number of lines in 'src'. */ int currentline; /* Current line number. */ } ldb; @@ -61,33 +61,71 @@ void ldbDisable(void) { ldb.active = 0; } -void ldbStart(robj *source) { +static char **split_text_by_lines(const char *text, size_t len, int *lines) { + ValkeyModule_Assert(text != NULL && len > 0); + + int count = 1; + for (size_t i = 0; i < len; i++) { + if (text[i] == '\n') count++; + } + + char **result = ValkeyModule_Calloc(count, sizeof(char *)); + if (!result) { + ValkeyModule_Log(NULL, "error", "Failed to allocate memory for Lua source code lines."); + *lines = 0; + return NULL; + } + + size_t start = 0, idx = 0; + for (size_t i = 0; i <= len; i++) { + if (i == len || text[i] == '\n') { + size_t linelen = i - start; + char *line = ValkeyModule_Calloc(linelen + 1, 1); + if (line) { + memcpy(line, text + start, linelen); + line[linelen] = '\0'; + result[idx++] = line; + } + start = i + 1; + } + } + *lines = idx; + return result; +} + +void ldbStart(ValkeyModuleString *source) { ldb.active = 1; /* First argument of EVAL is the script itself. We split it into different * lines since this is the way the debugger accesses the source code. */ - sds srcstring = sdsdup(source->ptr); - size_t srclen = sdslen(srcstring); - while (srclen && (srcstring[srclen - 1] == '\n' || srcstring[srclen - 1] == '\r')) { - srcstring[--srclen] = '\0'; + size_t srclen; + const char *src_raw = ValkeyModule_StringPtrLen(source, &srclen); + while (srclen && (src_raw[srclen - 1] == '\n' || src_raw[srclen - 1] == '\r')) { + --srclen; } - sdssetlen(srcstring, srclen); - ldb.src = sdssplitlen(srcstring, sdslen(srcstring), "\n", 1, &ldb.lines); - sdsfree(srcstring); + ldb.src = split_text_by_lines(src_raw, srclen, &ldb.lines); } void ldbEnd(void) { - sdsfreesplitres(ldb.src, ldb.lines); + for (int i = 0; i < ldb.lines; i++) { + ValkeyModule_Free(ldb.src[i]); + } + ValkeyModule_Free(ldb.src); ldb.lines = 0; ldb.active = 0; } -void ldbLog(sds entry) { - scriptingEngineDebuggerLog(createObject(OBJ_STRING, entry)); +void ldbLog(ValkeyModuleString *entry) { + ValkeyModule_ScriptingEngineDebuggerLog(entry, 0); +} + +void ldbLogCString(const char *c_str) { + ValkeyModuleString *entry = ValkeyModule_CreateString(NULL, c_str, strlen(c_str)); + ldbLog(entry); } void ldbSendLogs(void) { - scriptingEngineDebuggerFlushLogs(); + ValkeyModule_ScriptingEngineDebuggerFlushLogs(); } /* Return a pointer to ldb.src source code line, considering line to be @@ -149,7 +187,7 @@ void ldbLogSourceLine(int lnum) { prefix = " #"; else prefix = " "; - sds thisline = sdscatprintf(sdsempty(), "%s%-3d %s", prefix, lnum, line); + ValkeyModuleString *thisline = ValkeyModule_CreateStringPrintf(NULL, "%s%-3d %s", prefix, lnum, line); ldbLog(thisline); } @@ -168,35 +206,49 @@ static void ldbList(int around, int context) { } /* Append a human readable representation of the Lua value at position 'idx' - * on the stack of the 'lua' state, to the SDS string passed as argument. - * The new SDS string with the represented value attached is returned. + * on the stack of the 'lua' state, to the string passed as argument. + * The new string with the represented value attached is returned. * Used in order to implement ldbLogStackValue(). * * The element is neither automatically removed from the stack, nor * converted to a different type. */ #define LDB_MAX_VALUES_DEPTH (LUA_MINSTACK / 2) -static sds ldbCatStackValueRec(sds s, lua_State *lua, int idx, int level) { +static ValkeyModuleString *ldbCatStackValueRec(ValkeyModuleString *s, lua_State *lua, int idx, int level) { int t = lua_type(lua, idx); + ValkeyModuleString *old_s = NULL; + const char *prefix = ValkeyModule_StringPtrLen(s, NULL); - if (level++ == LDB_MAX_VALUES_DEPTH) return sdscat(s, ""); + if (level++ == LDB_MAX_VALUES_DEPTH) { + const char *msg = ""; + ValkeyModule_StringAppendBuffer(NULL, s, msg, strlen(msg)); + return s; + } switch (t) { case LUA_TSTRING: { size_t strl; char *strp = (char *)lua_tolstring(lua, idx, &strl); - s = sdscatrepr(s, strp, strl); + ValkeyModule_StringAppendBuffer(NULL, s, strp, strl); } break; - case LUA_TBOOLEAN: s = sdscat(s, lua_toboolean(lua, idx) ? "true" : "false"); break; - case LUA_TNUMBER: s = sdscatprintf(s, "%g", (double)lua_tonumber(lua, idx)); break; - case LUA_TNIL: s = sdscatlen(s, "nil", 3); break; + case LUA_TBOOLEAN: { + const char *bool_str = lua_toboolean(lua, idx) ? "true" : "false"; + ValkeyModule_StringAppendBuffer(NULL, s, bool_str, strlen(bool_str)); + break; + } + case LUA_TNUMBER: { + old_s = s; + s = ValkeyModule_CreateStringPrintf(NULL, "%s %g", prefix, (double)lua_tonumber(lua, idx)); + break; + } + case LUA_TNIL: ValkeyModule_StringAppendBuffer(NULL, s, "nil", 3); break; case LUA_TTABLE: { int expected_index = 1; /* First index we expect in an array. */ int is_array = 1; /* Will be set to null if check fails. */ /* Note: we create two representations at the same time, one * assuming the table is an array, one assuming it is not. At the * end we know what is true and select the right one. */ - sds repr1 = sdsempty(); - sds repr2 = sdsempty(); + ValkeyModuleString *repr1 = ValkeyModule_CreateString(NULL, "", 0); + ValkeyModuleString *repr2 = ValkeyModule_CreateString(NULL, "", 0); lua_pushnil(lua); /* The first key to start the iteration is nil. */ while (lua_next(lua, idx - 1)) { /* Test if so far the table looks like an array. */ @@ -204,25 +256,27 @@ static sds ldbCatStackValueRec(sds s, lua_State *lua, int idx, int level) { /* Stack now: table, key, value */ /* Array repr. */ repr1 = ldbCatStackValueRec(repr1, lua, -1, level); - repr1 = sdscatlen(repr1, "; ", 2); + ValkeyModule_StringAppendBuffer(NULL, repr1, "; ", 2); /* Full repr. */ - repr2 = sdscatlen(repr2, "[", 1); + ValkeyModule_StringAppendBuffer(NULL, repr2, "[", 1); repr2 = ldbCatStackValueRec(repr2, lua, -2, level); - repr2 = sdscatlen(repr2, "]=", 2); + ValkeyModule_StringAppendBuffer(NULL, repr2, "]=", 2); repr2 = ldbCatStackValueRec(repr2, lua, -1, level); - repr2 = sdscatlen(repr2, "; ", 2); + ValkeyModule_StringAppendBuffer(NULL, repr2, "; ", 2); lua_pop(lua, 1); /* Stack: table, key. Ready for next iteration. */ expected_index++; } - /* Strip the last " ;" from both the representations. */ - if (sdslen(repr1)) sdsrange(repr1, 0, -3); - if (sdslen(repr2)) sdsrange(repr2, 0, -3); + /* Select the right one and discard the other. */ - s = sdscatlen(s, "{", 1); - s = sdscatsds(s, is_array ? repr1 : repr2); - s = sdscatlen(s, "}", 1); - sdsfree(repr1); - sdsfree(repr2); + ValkeyModule_StringAppendBuffer(NULL, s, "{", 1); + size_t repr1_len; + const char *repr1_str = ValkeyModule_StringPtrLen(repr1, &repr1_len); + size_t repr2_len; + const char *repr2_str = ValkeyModule_StringPtrLen(repr2, &repr2_len); + ValkeyModule_StringAppendBuffer(NULL, s, is_array ? repr1_str : repr2_str, is_array ? repr1_len : repr2_len); + ValkeyModule_StringAppendBuffer(NULL, s, "}", 1); + ValkeyModule_FreeString(NULL, repr1); + ValkeyModule_FreeString(NULL, repr2); } break; case LUA_TFUNCTION: case LUA_TUSERDATA: @@ -238,39 +292,50 @@ static sds ldbCatStackValueRec(sds s, lua_State *lua, int idx, int level) { typename = "thread"; else if (t == LUA_TLIGHTUSERDATA) typename = "light-userdata"; - s = sdscatprintf(s, "\"%s@%p\"", typename, p); + old_s = s; + s = ValkeyModule_CreateStringPrintf(NULL, "%s \"%s@%p\"", prefix, typename, p); } break; - default: s = sdscat(s, "\"\""); break; + default: { + const char *unknown_str = "\"\""; + ValkeyModule_StringAppendBuffer(NULL, s, unknown_str, strlen(unknown_str)); + break; + } + } + + if (old_s) { + ValkeyModule_FreeString(NULL, old_s); } + return s; } /* Higher level wrapper for ldbCatStackValueRec() that just uses an initial * recursion level of '0'. */ -sds ldbCatStackValue(sds s, lua_State *lua, int idx) { +ValkeyModuleString *ldbCatStackValue(ValkeyModuleString *s, lua_State *lua, int idx) { return ldbCatStackValueRec(s, lua, idx, 0); } /* Produce a debugger log entry representing the value of the Lua object * currently on the top of the stack. The element is neither popped nor modified. * Check ldbCatStackValue() for the actual implementation. */ -static void ldbLogStackValue(lua_State *lua, char *prefix) { - sds s = sdsnew(prefix); - s = ldbCatStackValue(s, lua, -1); - scriptingEngineDebuggerLogWithMaxLen(createObject(OBJ_STRING, s)); +static void ldbLogStackValue(lua_State *lua, const char *prefix) { + ValkeyModuleString *p = ValkeyModule_CreateString(NULL, prefix, strlen(prefix)); + ValkeyModuleString *s = ldbCatStackValue(p, lua, -1); + ValkeyModule_ScriptingEngineDebuggerLog(s, 1); + ValkeyModule_FreeString(NULL, s); } /* Log a RESP reply as debugger output, in a human readable format. * If the resulting string is longer than 'len' plus a few more chars * used as prefix, it gets truncated. */ void ldbLogRespReply(char *reply) { - scriptingEngineDebuggerLogRespReplyStr(reply); + ValkeyModule_ScriptingEngineDebuggerLogRespReplyStr(reply); } /* Implements the "print " command of the Lua debugger. It scans for Lua * var "varname" starting from the current stack frame up to the top stack * frame. The first matching variable is printed. */ -static void ldbPrint(lua_State *lua, char *varname) { +static void ldbPrint(lua_State *lua, const char *varname) { lua_Debug ar; int l = 0; /* Stack level. */ @@ -296,7 +361,7 @@ static void ldbPrint(lua_State *lua, char *varname) { ldbLogStackValue(lua, " "); lua_pop(lua, 1); } else { - ldbLog(sdsnew("No such variable.")); + ldbLogCString("No such variable."); } } @@ -312,9 +377,10 @@ static void ldbPrintAll(lua_State *lua) { while ((name = lua_getlocal(lua, &ar, i)) != NULL) { i++; if (!strstr(name, "(*temporary)")) { - sds prefix = sdscatprintf(sdsempty(), " %s = ", name); + char *prefix; + asprintf(&prefix, " %s = ", name); ldbLogStackValue(lua, prefix); - sdsfree(prefix); + free(prefix); vars++; } lua_pop(lua, 1); @@ -322,45 +388,52 @@ static void ldbPrintAll(lua_State *lua) { } if (vars == 0) { - ldbLog(sdsnew("No local variables in the current context.")); + ldbLogCString("No local variables in the current context."); } } /* Implements the break command to list, add and remove breakpoints. */ -static void ldbBreak(robj **argv, int argc) { +static void ldbBreak(ValkeyModuleString **argv, int argc) { if (argc == 1) { if (ldb.bpcount == 0) { - ldbLog(sdsnew("No breakpoints set. Use 'b ' to add one.")); + ldbLogCString("No breakpoints set. Use 'b ' to add one."); return; } else { - ldbLog(sdscatfmt(sdsempty(), "%i breakpoints set:", ldb.bpcount)); + char *msg; + asprintf(&msg, "%i breakpoints set:", ldb.bpcount); + ldbLogCString(msg); + free(msg); int j; for (j = 0; j < ldb.bpcount; j++) ldbLogSourceLine(ldb.bp[j]); } } else { int j; for (j = 1; j < argc; j++) { - char *arg = argv[j]->ptr; - long line; - if (!string2l(arg, sdslen(arg), &line)) { - ldbLog(sdscatfmt(sdsempty(), "Invalid argument:'%s'", arg)); + long long line; + int res = ValkeyModule_StringToLongLong(argv[j], &line); + if (res != VALKEYMODULE_OK) { + const char *arg = ValkeyModule_StringPtrLen(argv[j], NULL); + char *msg; + asprintf(&msg, "Invalid argument:'%s'", arg); + ldbLogCString(msg); + free(msg); } else { if (line == 0) { ldb.bpcount = 0; - ldbLog(sdsnew("All breakpoints removed.")); + ldbLogCString("All breakpoints removed."); } else if (line > 0) { if (ldb.bpcount == LDB_BREAKPOINTS_MAX) { - ldbLog(sdsnew("Too many breakpoints set.")); + ldbLogCString("Too many breakpoints set."); } else if (ldbAddBreakpoint(line)) { ldbList(line, 1); } else { - ldbLog(sdsnew("Wrong line number.")); + ldbLogCString("Wrong line number."); } } else if (line < 0) { if (ldbDelBreakpoint(-line)) - ldbLog(sdsnew("Breakpoint removed.")); + ldbLogCString("Breakpoint removed."); else - ldbLog(sdsnew("No breakpoint in the specified line.")); + ldbLogCString("No breakpoint in the specified line."); } } } @@ -370,33 +443,49 @@ static void ldbBreak(robj **argv, int argc) { /* Implements the Lua debugger "eval" command. It just compiles the user * passed fragment of code and executes it, showing the result left on * the stack. */ -static void ldbEval(lua_State *lua, robj **argv, int argc) { +static void ldbEval(lua_State *lua, ValkeyModuleString **argv, int argc) { /* Glue the script together if it is composed of multiple arguments. */ - sds code = sdsempty(); + ValkeyModuleString *code = ValkeyModule_CreateString(NULL, "", 0); for (int j = 1; j < argc; j++) { - code = sdscatsds(code, argv[j]->ptr); - if (j != argc - 1) code = sdscatlen(code, " ", 1); + size_t arglen; + const char *arg = ValkeyModule_StringPtrLen(argv[j], &arglen); + ValkeyModule_StringAppendBuffer(NULL, code, arg, arglen); + if (j != argc - 1) { + ValkeyModule_StringAppendBuffer(NULL, code, " ", 1); + } } - sds expr = sdscatsds(sdsnew("return "), code); + + ValkeyModuleString *expr = ValkeyModule_CreateStringPrintf(NULL, "return %s", ValkeyModule_StringPtrLen(code, NULL)); + + size_t code_len; + const char *code_str = ValkeyModule_StringPtrLen(code, &code_len); + + size_t expr_len; + const char *expr_str = ValkeyModule_StringPtrLen(expr, &expr_len); /* Try to compile it as an expression, prepending "return ". */ - if (luaL_loadbuffer(lua, expr, sdslen(expr), "@ldb_eval")) { + if (luaL_loadbuffer(lua, expr_str, expr_len, "@ldb_eval")) { lua_pop(lua, 1); /* Failed? Try as a statement. */ - if (luaL_loadbuffer(lua, code, sdslen(code), "@ldb_eval")) { - ldbLog(sdscatfmt(sdsempty(), " %s", lua_tostring(lua, -1))); - lua_pop(lua, 1); - sdsfree(code); - sdsfree(expr); + if (luaL_loadbuffer(lua, code_str, code_len, "@ldb_eval")) { + char *err_msg; + asprintf(&err_msg, "Error compiling code: %s", lua_tostring(lua, -1)); + ldbLogCString(err_msg); + free(err_msg); + ValkeyModule_FreeString(NULL, code); + ValkeyModule_FreeString(NULL, expr); return; } } /* Call it. */ - sdsfree(code); - sdsfree(expr); + ValkeyModule_FreeString(NULL, code); + ValkeyModule_FreeString(NULL, expr); if (lua_pcall(lua, 0, 1, 0)) { - ldbLog(sdscatfmt(sdsempty(), " %s", lua_tostring(lua, -1))); + char *err_msg; + asprintf(&err_msg, " %s", lua_tostring(lua, -1)); + ldbLogCString(err_msg); + free(err_msg); lua_pop(lua, 1); return; } @@ -408,7 +497,7 @@ static void ldbEval(lua_State *lua, robj **argv, int argc) { * the implementation very simple: we just call the Lua server.call() command * implementation, with ldb.step enabled, so as a side effect the command * and its reply are logged. */ -static void ldbServer(lua_State *lua, robj **argv, int argc) { +static void ldbServer(lua_State *lua, ValkeyModuleString **argv, int argc) { int j; if (!lua_checkstack(lua, argc + 1)) { @@ -425,8 +514,11 @@ static void ldbServer(lua_State *lua, robj **argv, int argc) { lua_getglobal(lua, "server"); lua_pushstring(lua, "call"); lua_gettable(lua, -2); /* Stack: server, server.call */ - for (j = 1; j < argc; j++) - lua_pushlstring(lua, argv[j]->ptr, sdslen(argv[j]->ptr)); + for (j = 1; j < argc; j++) { + size_t arg_len; + const char *arg = ValkeyModule_StringPtrLen(argv[j], &arg_len); + lua_pushlstring(lua, arg, arg_len); + } ldb.step = 1; /* Force server.call() to log. */ lua_pcall(lua, argc - 1, 1, 0); /* Stack: server, result */ ldb.step = 0; /* Disable logging. */ @@ -442,141 +534,144 @@ static void ldbTrace(lua_State *lua) { while (lua_getstack(lua, level, &ar)) { lua_getinfo(lua, "Snl", &ar); if (strstr(ar.short_src, "user_script") != NULL) { - ldbLog(sdscatprintf(sdsempty(), "%s %s:", (level == 0) ? "In" : "From", ar.name ? ar.name : "top level")); + char *msg; + asprintf(&msg, "%s %s:", (level == 0) ? "In" : "From", ar.name ? ar.name : "top level"); + ldbLogCString(msg); + free(msg); ldbLogSourceLine(ar.currentline); } level++; } if (level == 0) { - ldbLog(sdsnew(" Can't retrieve Lua stack.")); + ldbLogCString(" Can't retrieve Lua stack."); } } #define CONTINUE_SCRIPT_EXECUTION 0 #define CONTINUE_READ_NEXT_COMMAND 1 -static int stepCommandHandler(robj **argv, size_t argc, void *context) { - UNUSED(argv); - UNUSED(argc); - UNUSED(context); +static int stepCommandHandler(ValkeyModuleString **argv, size_t argc, void *context) { + VALKEYMODULE_NOT_USED(argv); + VALKEYMODULE_NOT_USED(argc); + VALKEYMODULE_NOT_USED(context); ldb.step = 1; return CONTINUE_SCRIPT_EXECUTION; } -static int continueCommandHandler(robj **argv, size_t argc, void *context) { - UNUSED(argv); - UNUSED(argc); - UNUSED(context); +static int continueCommandHandler(ValkeyModuleString **argv, size_t argc, void *context) { + VALKEYMODULE_NOT_USED(argv); + VALKEYMODULE_NOT_USED(argc); + VALKEYMODULE_NOT_USED(context); return CONTINUE_SCRIPT_EXECUTION; } -static int listCommandHandler(robj **argv, size_t argc, void *context) { - UNUSED(context); +static int listCommandHandler(ValkeyModuleString **argv, size_t argc, void *context) { + VALKEYMODULE_NOT_USED(context); int around = ldb.currentline, ctx = 5; if (argc > 1) { - int num = atoi(argv[1]->ptr); + int num = atoi(ValkeyModule_StringPtrLen(argv[1], NULL)); if (num > 0) around = num; } - if (argc > 2) ctx = atoi(argv[2]->ptr); + if (argc > 2) ctx = atoi(ValkeyModule_StringPtrLen(argv[2], NULL)); ldbList(around, ctx); - scriptingEngineDebuggerFlushLogs(); + ValkeyModule_ScriptingEngineDebuggerFlushLogs(); return CONTINUE_READ_NEXT_COMMAND; } -static int wholeCommandHandler(robj **argv, size_t argc, void *context) { - UNUSED(argv); - UNUSED(argc); - UNUSED(context); +static int wholeCommandHandler(ValkeyModuleString **argv, size_t argc, void *context) { + VALKEYMODULE_NOT_USED(argv); + VALKEYMODULE_NOT_USED(argc); + VALKEYMODULE_NOT_USED(context); ldbList(1, 1000000); - scriptingEngineDebuggerFlushLogs(); + ValkeyModule_ScriptingEngineDebuggerFlushLogs(); return CONTINUE_READ_NEXT_COMMAND; } -static int printCommandHandler(robj **argv, size_t argc, void *context) { - serverAssert(context != NULL); +static int printCommandHandler(ValkeyModuleString **argv, size_t argc, void *context) { + ValkeyModule_Assert(context != NULL); lua_State *lua = context; if (argc == 2) { - ldbPrint(lua, argv[1]->ptr); + ldbPrint(lua, ValkeyModule_StringPtrLen(argv[1], NULL)); } else { ldbPrintAll(lua); } - scriptingEngineDebuggerFlushLogs(); + ValkeyModule_ScriptingEngineDebuggerFlushLogs(); return CONTINUE_READ_NEXT_COMMAND; } -static int breakCommandHandler(robj **argv, size_t argc, void *context) { - UNUSED(context); +static int breakCommandHandler(ValkeyModuleString **argv, size_t argc, void *context) { + VALKEYMODULE_NOT_USED(context); ldbBreak(argv, argc); - scriptingEngineDebuggerFlushLogs(); + ValkeyModule_ScriptingEngineDebuggerFlushLogs(); return CONTINUE_READ_NEXT_COMMAND; } -static int traceCommandHandler(robj **argv, size_t argc, void *context) { - UNUSED(argv); - UNUSED(argc); - UNUSED(context); +static int traceCommandHandler(ValkeyModuleString **argv, size_t argc, void *context) { + VALKEYMODULE_NOT_USED(argv); + VALKEYMODULE_NOT_USED(argc); + VALKEYMODULE_NOT_USED(context); lua_State *lua = context; ldbTrace(lua); - scriptingEngineDebuggerFlushLogs(); + ValkeyModule_ScriptingEngineDebuggerFlushLogs(); return CONTINUE_READ_NEXT_COMMAND; } -static int evalCommandHandler(robj **argv, size_t argc, void *context) { - serverAssert(context != NULL); +static int evalCommandHandler(ValkeyModuleString **argv, size_t argc, void *context) { + ValkeyModule_Assert(context != NULL); lua_State *lua = context; ldbEval(lua, argv, argc); - scriptingEngineDebuggerFlushLogs(); + ValkeyModule_ScriptingEngineDebuggerFlushLogs(); return CONTINUE_READ_NEXT_COMMAND; } -static int valkeyCommandHandler(robj **argv, size_t argc, void *context) { - serverAssert(context != NULL); +static int valkeyCommandHandler(ValkeyModuleString **argv, size_t argc, void *context) { + ValkeyModule_Assert(context != NULL); lua_State *lua = context; ldbServer(lua, argv, argc); - scriptingEngineDebuggerFlushLogs(); + ValkeyModule_ScriptingEngineDebuggerFlushLogs(); return CONTINUE_READ_NEXT_COMMAND; } -static int abortCommandHandler(robj **argv, size_t argc, void *context) { - UNUSED(argv); - UNUSED(argc); - UNUSED(context); - serverAssert(context != NULL); +static int abortCommandHandler(ValkeyModuleString **argv, size_t argc, void *context) { + VALKEYMODULE_NOT_USED(argv); + VALKEYMODULE_NOT_USED(argc); + VALKEYMODULE_NOT_USED(context); + ValkeyModule_Assert(context != NULL); lua_State *lua = context; luaPushError(lua, "script aborted for user request"); luaError(lua); return CONTINUE_READ_NEXT_COMMAND; } -static debuggerCommand *commands_array_cache = NULL; +static ValkeyModuleScriptingEngineDebuggerCommand *commands_array_cache = NULL; static size_t commands_array_len = 0; void ldbGenerateDebuggerCommandsArray(lua_State *lua, - const debuggerCommand **commands, + const ValkeyModuleScriptingEngineDebuggerCommand **commands, size_t *commands_len) { - static debuggerCommandParam list_params[] = { + static ValkeyModuleScriptingEngineDebuggerCommandParam list_params[] = { {.name = "line", .optional = 1}, {.name = "ctx", .optional = 1}, }; - static debuggerCommandParam print_params[] = { + static ValkeyModuleScriptingEngineDebuggerCommandParam print_params[] = { {.name = "var", .optional = 1}, }; - static debuggerCommandParam break_params[] = { + static ValkeyModuleScriptingEngineDebuggerCommandParam break_params[] = { {.name = "line|-line", .optional = 1}, }; - static debuggerCommandParam eval_params[] = { + static ValkeyModuleScriptingEngineDebuggerCommandParam eval_params[] = { {.name = "code", .optional = 0, .variadic = 1}, }; - static debuggerCommandParam valkey_params[] = { + static ValkeyModuleScriptingEngineDebuggerCommandParam valkey_params[] = { {.name = "cmd", .optional = 0, .variadic = 1}, }; if (commands_array_cache == NULL) { - debuggerCommand commands_array[] = { + ValkeyModuleScriptingEngineDebuggerCommand commands_array[] = { VALKEYMODULE_SCRIPTING_ENGINE_DEBUGGER_COMMAND("step", 1, NULL, 0, "Run current line and stop again.", 0, stepCommandHandler), VALKEYMODULE_SCRIPTING_ENGINE_DEBUGGER_COMMAND("next", 1, NULL, 0, "Alias for step.", 0, stepCommandHandler), VALKEYMODULE_SCRIPTING_ENGINE_DEBUGGER_COMMAND("continue", 1, NULL, 0, "Run till next breakpoint.", 0, continueCommandHandler), @@ -592,9 +687,9 @@ void ldbGenerateDebuggerCommandsArray(lua_State *lua, VALKEYMODULE_SCRIPTING_ENGINE_DEBUGGER_COMMAND_WITH_CTX("abort", 1, NULL, 0, "Stop the execution of the script. In sync mode dataset changes will be retained.", 0, abortCommandHandler, lua), }; - commands_array_len = sizeof(commands_array) / sizeof(debuggerCommand); + commands_array_len = sizeof(commands_array) / sizeof(ValkeyModuleScriptingEngineDebuggerCommand); - commands_array_cache = zmalloc(sizeof(debuggerCommand) * commands_array_len); + commands_array_cache = ValkeyModule_Calloc(commands_array_len, sizeof(ValkeyModuleScriptingEngineDebuggerCommand)); memcpy(commands_array_cache, &commands_array, sizeof(commands_array)); } @@ -607,13 +702,14 @@ void ldbGenerateDebuggerCommandsArray(lua_State *lua, * C_ERR if the client closed the connection or is timing out. */ int ldbRepl(lua_State *lua) { int client_disconnected = 0; - robj *err = NULL; + ValkeyModuleString *err = NULL; - scriptingEngineDebuggerProcessCommands(&client_disconnected, &err); + ValkeyModule_ScriptingEngineDebuggerProcessCommands(&client_disconnected, &err); if (err) { - luaPushError(lua, err->ptr); - decrRefCount(err); + const char *err_msg = ValkeyModule_StringPtrLen(err, NULL); + luaPushError(lua, err_msg); + ValkeyModule_Free(err); luaError(lua); } else if (client_disconnected) { /* Make sure the script runs without user input since the diff --git a/src/lua/debug_lua.h b/src/modules/lua/debug_lua.h similarity index 68% rename from src/lua/debug_lua.h rename to src/modules/lua/debug_lua.h index c197bae9ea..b0a083cbde 100644 --- a/src/lua/debug_lua.h +++ b/src/modules/lua/debug_lua.h @@ -1,10 +1,8 @@ #ifndef _LUA_DEBUG_H_ #define _LUA_DEBUG_H_ -#include "../scripting_engine.h" +#include "../../valkeymodule.h" -typedef char *sds; -typedef struct serverObject robj; typedef struct lua_State lua_State; typedef struct client client; @@ -13,16 +11,17 @@ int ldbIsEnabled(void); void ldbDisable(void); void ldbEnable(void); int ldbIsActive(void); -void ldbStart(robj *source); +void ldbStart(ValkeyModuleString *source); void ldbEnd(void); -void ldbLog(sds entry); +void ldbLog(ValkeyModuleString *entry); +void ldbLogCString(const char *c_str); void ldbSendLogs(void); void ldbLogRespReply(char *reply); int ldbGetCurrentLine(void); void ldbSetCurrentLine(int line); void ldbLogSourceLine(int lnum); -sds ldbCatStackValue(sds s, lua_State *lua, int idx); +ValkeyModuleString *ldbCatStackValue(ValkeyModuleString *s, lua_State *lua, int idx); void ldbSetBreakpointOnNextLine(int enable); int ldbIsBreakpointOnNextLineEnabled(void); int ldbShouldBreak(void); @@ -31,7 +30,7 @@ void ldbSetStepMode(int enable); int ldbRepl(lua_State *lua); void ldbGenerateDebuggerCommandsArray(lua_State *lua, - const debuggerCommand **commands, + const ValkeyModuleScriptingEngineDebuggerCommand **commands, size_t *commands_len); #endif /* _LUA_DEBUG_H_ */ diff --git a/src/modules/lua/engine_lua.c b/src/modules/lua/engine_lua.c new file mode 100644 index 0000000000..6cb736c806 --- /dev/null +++ b/src/modules/lua/engine_lua.c @@ -0,0 +1,546 @@ +/* + * Copyright (c) Valkey Contributors + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +#include "../../valkeymodule.h" +#include +#include +#include +#include +#if defined(__GLIBC__) && !defined(USE_LIBC) +#include +#endif +#include + +#include "engine_structs.h" +#include "function_lua.h" +#include "script_lua.h" +#include "debug_lua.h" + + +#define LUA_ENGINE_NAME "LUA" +#define REGISTRY_ERROR_HANDLER_NAME "__ERROR_HANDLER__" + +/* Adds server.debug() function used by lua debugger + * + * Log a string message into the output console. + * Can take multiple arguments that will be separated by commas. + * Nothing is returned to the caller. */ +static int luaServerDebugCommand(lua_State *lua) { + if (!ldbIsActive()) return 0; + int argc = lua_gettop(lua); + ValkeyModuleString *log = ValkeyModule_CreateStringPrintf(NULL, " line %d: ", ldbGetCurrentLine()); + while (argc--) { + log = ldbCatStackValue(log, lua, -1 - argc); + if (argc != 0) { + ValkeyModule_StringAppendBuffer(NULL, log, ", ", 2); + } + } + ldbLog(log); + return 0; +} + +/* Adds server.breakpoint() function used by lua debugger. + * + * Allows to stop execution during a debugging session from within + * the Lua code implementation, like if a breakpoint was set in the code + * immediately after the function. */ +static int luaServerBreakpointCommand(lua_State *lua) { + if (ldbIsActive()) { + ldbSetBreakpointOnNextLine(1); + lua_pushboolean(lua, 1); + } else { + lua_pushboolean(lua, 0); + } + return 1; +} + +/* Adds server.replicate_commands() + * + * DEPRECATED: Now do nothing and always return true. + * Turn on single commands replication if the script never called + * a write command so far, and returns true. Otherwise if the script + * already started to write, returns false and stick to whole scripts + * replication, which is our default. */ +int luaServerReplicateCommandsCommand(lua_State *lua) { + lua_pushboolean(lua, 1); + return 1; +} + +static void luaStateInstallErrorHandler(lua_State *lua) { + /* Add a helper function we use for pcall error reporting. + * Note that when the error is in the C function we want to report the + * information about the caller, that's what makes sense from the point + * of view of the user debugging a script. */ + lua_pushstring(lua, REGISTRY_ERROR_HANDLER_NAME); + char *errh_func = "local dbg = debug\n" + "debug = nil\n" + "local error_handler = function (err)\n" + " local i = dbg.getinfo(2,'nSl')\n" + " if i and i.what == 'C' then\n" + " i = dbg.getinfo(3,'nSl')\n" + " end\n" + " if type(err) ~= 'table' then\n" + " err = {err='ERR ' .. tostring(err)}" + " end" + " if i then\n" + " err['source'] = i.source\n" + " err['line'] = i.currentline\n" + " end" + " return err\n" + "end\n" + "return error_handler"; + luaL_loadbuffer(lua, errh_func, strlen(errh_func), "@err_handler_def"); + lua_pcall(lua, 0, 1, 0); + lua_settable(lua, LUA_REGISTRYINDEX); +} + +static void luaStateLockGlobalTable(lua_State *lua) { + /* Lock the global table from any changes */ + lua_pushvalue(lua, LUA_GLOBALSINDEX); + luaSetErrorMetatable(lua); + /* Recursively lock all tables that can be reached from the global table */ + luaSetTableProtectionRecursively(lua); + lua_pop(lua, 1); + /* Set metatables of basic types (string, number, nil etc.) readonly. */ + luaSetTableProtectionForBasicTypes(lua); +} + + +static void initializeEvalLuaState(lua_State *lua) { + /* register debug commands. we only need to add it under 'server' as 'redis' + * is effectively aliased to 'server' table at this point. */ + lua_getglobal(lua, "server"); + + /* server.breakpoint */ + lua_pushstring(lua, "breakpoint"); + lua_pushcfunction(lua, luaServerBreakpointCommand); + lua_settable(lua, -3); + + /* server.debug */ + lua_pushstring(lua, "debug"); + lua_pushcfunction(lua, luaServerDebugCommand); + lua_settable(lua, -3); + + /* server.replicate_commands */ + lua_pushstring(lua, "replicate_commands"); + lua_pushcfunction(lua, luaServerReplicateCommandsCommand); + lua_settable(lua, -3); + + lua_setglobal(lua, "server"); + + /* Duplicate the function with __server__err__hanler and + * __redis__err_handler name for backwards compatibility. */ + lua_pushstring(lua, REGISTRY_ERROR_HANDLER_NAME); + lua_gettable(lua, LUA_REGISTRYINDEX); + lua_setglobal(lua, "__server__err__handler"); + lua_getglobal(lua, "__server__err__handler"); + lua_setglobal(lua, "__redis__err__handler"); +} + +static uint32_t parse_semver(const char *version) { + unsigned int major = 0, minor = 0, patch = 0; + sscanf(version, "%u.%u.%u", &major, &minor, &patch); + return ((major & 0xFF) << 16) | ((minor & 0xFF) << 8) | (patch & 0xFF); +} + +static void get_version_info(ValkeyModuleCtx *ctx, + char **redis_version, + uint32_t *redis_version_num, + char **server_name, + char **valkey_version, + uint32_t *valkey_version_num) { + ValkeyModuleServerInfoData *info = ValkeyModule_GetServerInfo(ctx, "server"); + ValkeyModule_Assert(info != NULL); + + const char *rv = ValkeyModule_ServerInfoGetFieldC(info, "redis_version"); + *redis_version = lm_strcpy(rv); + *redis_version_num = parse_semver(*redis_version); + + const char *sn = ValkeyModule_ServerInfoGetFieldC(info, "server_name"); + *server_name = lm_strcpy(sn); + + const char *vv = ValkeyModule_ServerInfoGetFieldC(info, "valkey_version"); + *valkey_version = lm_strcpy(vv); + *valkey_version_num = parse_semver(*valkey_version); + + ValkeyModule_FreeServerInfo(ctx, info); +} + +static void initializeLuaState(luaEngineCtx *lua_engine_ctx, + ValkeyModuleScriptingEngineSubsystemType type) { + lua_State *lua = lua_open(); + + if (type == VMSE_EVAL) { + lua_engine_ctx->eval_lua = lua; + } else { + ValkeyModule_Assert(type == VMSE_FUNCTION); + lua_engine_ctx->function_lua = lua; + } + + luaRegisterServerAPI(lua_engine_ctx, lua); + luaStateInstallErrorHandler(lua); + + if (type == VMSE_EVAL) { + initializeEvalLuaState(lua); + luaStateLockGlobalTable(lua); + } else { + luaStateLockGlobalTable(lua); + luaFunctionInitializeLuaState(lua_engine_ctx, lua); + } +} + +static struct luaEngineCtx *createEngineContext(ValkeyModuleCtx *ctx) { + luaEngineCtx *lua_engine_ctx = ValkeyModule_Alloc(sizeof(*lua_engine_ctx)); + + get_version_info(ctx, + &lua_engine_ctx->redis_version, + &lua_engine_ctx->redis_version_num, + &lua_engine_ctx->server_name, + &lua_engine_ctx->valkey_version, + &lua_engine_ctx->valkey_version_num); + + lua_engine_ctx->lua_enable_insecure_api = 0; + + initializeLuaState(lua_engine_ctx, VMSE_EVAL); + initializeLuaState(lua_engine_ctx, VMSE_FUNCTION); + + return lua_engine_ctx; +} + +static void destroyEngineContext(luaEngineCtx *lua_engine_ctx) { + lua_close(lua_engine_ctx->eval_lua); + lua_close(lua_engine_ctx->function_lua); + ValkeyModule_Free(lua_engine_ctx->redis_version); + ValkeyModule_Free(lua_engine_ctx->server_name); + ValkeyModule_Free(lua_engine_ctx->valkey_version); + ValkeyModule_Free(lua_engine_ctx); +} + +static ValkeyModuleScriptingEngineMemoryInfo luaEngineGetMemoryInfo(ValkeyModuleCtx *module_ctx, + ValkeyModuleScriptingEngineCtx *engine_ctx, + ValkeyModuleScriptingEngineSubsystemType type) { + VALKEYMODULE_NOT_USED(module_ctx); + luaEngineCtx *lua_engine_ctx = engine_ctx; + ValkeyModuleScriptingEngineMemoryInfo mem_info = {0}; + + if (type == VMSE_EVAL || type == VMSE_ALL) { + mem_info.used_memory += luaMemory(lua_engine_ctx->eval_lua); + } + if (type == VMSE_FUNCTION || type == VMSE_ALL) { + mem_info.used_memory += luaMemory(lua_engine_ctx->function_lua); + } + + mem_info.engine_memory_overhead = ValkeyModule_MallocSize(engine_ctx); + + return mem_info; +} + +static ValkeyModuleScriptingEngineCompiledFunction **luaEngineCompileCode(ValkeyModuleCtx *module_ctx, + ValkeyModuleScriptingEngineCtx *engine_ctx, + ValkeyModuleScriptingEngineSubsystemType type, + const char *code, + size_t code_len, + size_t timeout, + size_t *out_num_compiled_functions, + ValkeyModuleString **err) { + luaEngineCtx *lua_engine_ctx = (luaEngineCtx *)engine_ctx; + ValkeyModuleScriptingEngineCompiledFunction **functions = NULL; + + if (type == VMSE_EVAL) { + lua_State *lua = lua_engine_ctx->eval_lua; + + if (luaL_loadbuffer( + lua, code, code_len, "@user_script")) { + *err = ValkeyModule_CreateStringPrintf(module_ctx, "Error compiling script (new function): %s", lua_tostring(lua, -1)); + lua_pop(lua, 1); + return functions; + } + + ValkeyModule_Assert(lua_isfunction(lua, -1)); + int function_ref = luaL_ref(lua, LUA_REGISTRYINDEX); + + luaFunction *script = ValkeyModule_Calloc(1, sizeof(luaFunction)); + *script = (luaFunction){ + .lua = lua, + .function_ref = function_ref, + }; + + ValkeyModuleScriptingEngineCompiledFunction *func = ValkeyModule_Alloc(sizeof(*func)); + *func = (ValkeyModuleScriptingEngineCompiledFunction){ + .name = NULL, + .function = script, + .desc = NULL, + .f_flags = 0}; + + *out_num_compiled_functions = 1; + functions = ValkeyModule_Calloc(1, sizeof(ValkeyModuleScriptingEngineCompiledFunction *)); + *functions = func; + } else { + functions = luaFunctionLibraryCreate(lua_engine_ctx->function_lua, + code, + timeout, + out_num_compiled_functions, + err); + } + + return functions; +} + +static void luaEngineFunctionCall(ValkeyModuleCtx *module_ctx, + ValkeyModuleScriptingEngineCtx *engine_ctx, + ValkeyModuleScriptingEngineServerRuntimeCtx *server_ctx, + ValkeyModuleScriptingEngineCompiledFunction *compiled_function, + ValkeyModuleScriptingEngineSubsystemType type, + ValkeyModuleString **keys, + size_t nkeys, + ValkeyModuleString **args, + size_t nargs) { + luaEngineCtx *lua_engine_ctx = (luaEngineCtx *)engine_ctx; + lua_State *lua = type == VMSE_EVAL ? lua_engine_ctx->eval_lua : lua_engine_ctx->function_lua; + luaFunction *script = compiled_function->function; + int lua_function_ref = script->function_ref; + + /* Push the pcall error handler function on the stack. */ + lua_pushstring(lua, REGISTRY_ERROR_HANDLER_NAME); + lua_gettable(lua, LUA_REGISTRYINDEX); + + lua_rawgeti(lua, LUA_REGISTRYINDEX, lua_function_ref); + ValkeyModule_Assert(!lua_isnil(lua, -1)); + + luaCallFunction(module_ctx, + server_ctx, + type, + lua, + keys, + nkeys, + args, + nargs, + type == VMSE_EVAL ? ldbIsActive() : 0, + lua_engine_ctx->lua_enable_insecure_api); + + lua_pop(lua, 1); /* Remove the error handler. */ +} + +static void resetLuaContext(void *context) { + lua_State *lua = context; + lua_gc(lua, LUA_GCCOLLECT, 0); + lua_close(lua); + +#if defined(__GLIBC__) && !defined(USE_LIBC) + /* The lua interpreter may hold a lot of memory internally, and lua is + * using libc. libc may take a bit longer to return the memory to the OS, + * so after lua_close, we call malloc_trim try to purge it earlier. + * + * We do that only when the server itself does not use libc. When Lua and the server + * use different allocators, one won't use the fragmentation holes of the + * other, and released memory can take a long time until it is returned to + * the OS. */ + malloc_trim(0); +#endif +} + +static int isLuaInsecureAPIEnabled(ValkeyModuleCtx *module_ctx) { + int result = 0; + ValkeyModuleCallReply *reply = ValkeyModule_Call(module_ctx, "CONFIG", "ccE", "GET", "lua-enable-insecure-api"); + if (ValkeyModule_CallReplyType(reply) == VALKEYMODULE_REPLY_ERROR) { + ValkeyModule_Log(module_ctx, + "warning", + "Unable to determine 'lua-enable-insecure-api' configuration value: %s", + ValkeyModule_CallReplyStringPtr(reply, NULL)); + ValkeyModule_FreeCallReply(reply); + return 0; + } + ValkeyModule_Assert(ValkeyModule_CallReplyType(reply) == VALKEYMODULE_REPLY_ARRAY && + ValkeyModule_CallReplyLength(reply) == 2); + ValkeyModuleCallReply *val = ValkeyModule_CallReplyArrayElement(reply, 1); + ValkeyModule_Assert(ValkeyModule_CallReplyType(val) == VALKEYMODULE_REPLY_STRING); + const char *val_str = ValkeyModule_CallReplyStringPtr(val, NULL); + result = strncmp(val_str, "yes", 3) == 0; + ValkeyModule_FreeCallReply(reply); + return result; +} + +static ValkeyModuleScriptingEngineCallableLazyEnvReset *luaEngineResetEnv(ValkeyModuleCtx *module_ctx, + ValkeyModuleScriptingEngineCtx *engine_ctx, + ValkeyModuleScriptingEngineSubsystemType type, + int async) { + VALKEYMODULE_NOT_USED(module_ctx); + luaEngineCtx *lua_engine_ctx = (luaEngineCtx *)engine_ctx; + ValkeyModule_Assert(type == VMSE_EVAL || type == VMSE_FUNCTION); + lua_State *lua = type == VMSE_EVAL ? lua_engine_ctx->eval_lua : lua_engine_ctx->function_lua; + ValkeyModule_Assert(lua); + ValkeyModuleScriptingEngineCallableLazyEnvReset *callback = NULL; + + if (async) { + callback = ValkeyModule_Calloc(1, sizeof(*callback)); + *callback = (ValkeyModuleScriptingEngineCallableLazyEnvReset){ + .context = lua, + .engineLazyEnvResetCallback = resetLuaContext, + }; + } else { + resetLuaContext(lua); + } + + lua_engine_ctx->lua_enable_insecure_api = isLuaInsecureAPIEnabled(module_ctx); + + initializeLuaState(lua_engine_ctx, type); + + return callback; +} + +static size_t luaEngineFunctionMemoryOverhead(ValkeyModuleCtx *module_ctx, + ValkeyModuleScriptingEngineCompiledFunction *compiled_function) { + VALKEYMODULE_NOT_USED(module_ctx); + return ValkeyModule_MallocSize(compiled_function->function) + + (compiled_function->name ? ValkeyModule_MallocSize(compiled_function->name) : 0) + + (compiled_function->desc ? ValkeyModule_MallocSize(compiled_function->desc) : 0) + + ValkeyModule_MallocSize(compiled_function); +} + +static void luaEngineFreeFunction(ValkeyModuleCtx *module_ctx, + ValkeyModuleScriptingEngineCtx *engine_ctx, + ValkeyModuleScriptingEngineSubsystemType type, + ValkeyModuleScriptingEngineCompiledFunction *compiled_function) { + VALKEYMODULE_NOT_USED(module_ctx); + ValkeyModule_Assert(type == VMSE_EVAL || type == VMSE_FUNCTION); + + luaEngineCtx *lua_engine_ctx = engine_ctx; + lua_State *lua = type == VMSE_EVAL ? lua_engine_ctx->eval_lua : lua_engine_ctx->function_lua; + ValkeyModule_Assert(lua); + + luaFunction *script = (luaFunction *)compiled_function->function; + if (lua == script->lua) { + /* The lua context is still the same, which means that we're not + * resetting the whole eval context, and therefore, we need to + * delete the function from the lua context. + */ + lua_unref(lua, script->function_ref); + } + ValkeyModule_Free(script); + + if (compiled_function->name) { + ValkeyModule_Free(compiled_function->name); + } + if (compiled_function->desc) { + ValkeyModule_Free(compiled_function->desc); + } + ValkeyModule_Free(compiled_function); +} + +static ValkeyModuleScriptingEngineDebuggerEnableRet luaEngineDebuggerEnable(ValkeyModuleCtx *module_ctx, + ValkeyModuleScriptingEngineCtx *engine_ctx, + ValkeyModuleScriptingEngineSubsystemType type, + const ValkeyModuleScriptingEngineDebuggerCommand **commands, + size_t *commands_len) { + VALKEYMODULE_NOT_USED(module_ctx); + + if (type != VMSE_EVAL) { + return VMSE_DEBUG_NOT_SUPPORTED; + } + + ldbEnable(); + + luaEngineCtx *lua_engine_ctx = engine_ctx; + ldbGenerateDebuggerCommandsArray(lua_engine_ctx->eval_lua, + commands, + commands_len); + + return VMSE_DEBUG_ENABLED; +} + +static void luaEngineDebuggerDisable(ValkeyModuleCtx *module_ctx, + ValkeyModuleScriptingEngineCtx *engine_ctx, + ValkeyModuleScriptingEngineSubsystemType type) { + VALKEYMODULE_NOT_USED(module_ctx); + VALKEYMODULE_NOT_USED(engine_ctx); + VALKEYMODULE_NOT_USED(type); + ldbDisable(); +} + +static void luaEngineDebuggerStart(ValkeyModuleCtx *module_ctx, + ValkeyModuleScriptingEngineCtx *engine_ctx, + ValkeyModuleScriptingEngineSubsystemType type, + ValkeyModuleString *source) { + VALKEYMODULE_NOT_USED(module_ctx); + VALKEYMODULE_NOT_USED(engine_ctx); + VALKEYMODULE_NOT_USED(type); + ldbStart(source); +} + +static void luaEngineDebuggerEnd(ValkeyModuleCtx *module_ctx, + ValkeyModuleScriptingEngineCtx *engine_ctx, + ValkeyModuleScriptingEngineSubsystemType type) { + VALKEYMODULE_NOT_USED(module_ctx); + VALKEYMODULE_NOT_USED(engine_ctx); + VALKEYMODULE_NOT_USED(type); + ldbEnd(); +} + +static struct luaEngineCtx *engine_ctx = NULL; + +int ValkeyModule_OnLoad(ValkeyModuleCtx *ctx, + ValkeyModuleString **argv, + int argc) { + VALKEYMODULE_NOT_USED(argv); + VALKEYMODULE_NOT_USED(argc); + + if (ValkeyModule_Init(ctx, "lua", 1, VALKEYMODULE_APIVER_1) == VALKEYMODULE_ERR) { + return VALKEYMODULE_ERR; + } + + ValkeyModule_SetModuleOptions(ctx, VALKEYMODULE_OPTIONS_HANDLE_REPL_ASYNC_LOAD | + VALKEYMODULE_OPTIONS_HANDLE_ATOMIC_SLOT_MIGRATION); + + engine_ctx = createEngineContext(ctx); + + if (ValkeyModule_LoadConfigs(ctx) == VALKEYMODULE_ERR) { + ValkeyModule_Log(ctx, "warning", "Failed to load LUA module configs"); + destroyEngineContext(engine_ctx); + engine_ctx = NULL; + return VALKEYMODULE_ERR; + } + + ValkeyModuleScriptingEngineMethods methods = { + .version = VALKEYMODULE_SCRIPTING_ENGINE_ABI_VERSION, + .compile_code = luaEngineCompileCode, + .free_function = luaEngineFreeFunction, + .call_function = luaEngineFunctionCall, + .get_function_memory_overhead = luaEngineFunctionMemoryOverhead, + .reset_env = luaEngineResetEnv, + .get_memory_info = luaEngineGetMemoryInfo, + .debugger_enable = luaEngineDebuggerEnable, + .debugger_disable = luaEngineDebuggerDisable, + .debugger_start = luaEngineDebuggerStart, + .debugger_end = luaEngineDebuggerEnd, + }; + + int result = ValkeyModule_RegisterScriptingEngine(ctx, + LUA_ENGINE_NAME, + engine_ctx, + &methods); + + if (result == VALKEYMODULE_ERR) { + ValkeyModule_Log(ctx, "warning", "Failed to register LUA scripting engine"); + destroyEngineContext(engine_ctx); + engine_ctx = NULL; + return VALKEYMODULE_ERR; + } + + engine_ctx->lua_enable_insecure_api = isLuaInsecureAPIEnabled(ctx); + + return VALKEYMODULE_OK; +} + +int ValkeyModule_OnUnload(ValkeyModuleCtx *ctx) { + if (ValkeyModule_UnregisterScriptingEngine(ctx, LUA_ENGINE_NAME) != VALKEYMODULE_OK) { + ValkeyModule_Log(ctx, "error", "Failed to unregister engine"); + return VALKEYMODULE_ERR; + } + + destroyEngineContext(engine_ctx); + engine_ctx = NULL; + + return VALKEYMODULE_OK; +} diff --git a/src/modules/lua/engine_structs.h b/src/modules/lua/engine_structs.h new file mode 100644 index 0000000000..9d10e53fdb --- /dev/null +++ b/src/modules/lua/engine_structs.h @@ -0,0 +1,25 @@ +#ifndef _ENGINE_STRUCTS_H_ +#define _ENGINE_STRUCTS_H_ + +#include +#include + +typedef struct luaEngineCtx { + lua_State *eval_lua; /* The Lua interpreter for EVAL commands. We use just one for all EVAL calls */ + lua_State *function_lua; /* The Lua interpreter for FCALL commands. We use just one for all FCALL calls */ + + char *redis_version; + uint32_t redis_version_num; + char *server_name; + char *valkey_version; + uint32_t valkey_version_num; + + int lua_enable_insecure_api; +} luaEngineCtx; + +typedef struct luaFunction { + lua_State *lua; /* Pointer to the lua context where this function was created. Only used in EVAL context. */ + int function_ref; /* Special ID that allows getting the Lua function object from the Lua registry */ +} luaFunction; + +#endif /* _ENGINE_STRUCTS_H_ */ diff --git a/src/lua/function_lua.c b/src/modules/lua/function_lua.c similarity index 71% rename from src/lua/function_lua.c rename to src/modules/lua/function_lua.c index 132d70c77d..3dd9b169d3 100644 --- a/src/lua/function_lua.c +++ b/src/modules/lua/function_lua.c @@ -41,12 +41,11 @@ #include "function_lua.h" #include "script_lua.h" +#include "list.h" -#include "../script.h" -#include "../adlist.h" -#include "../monotonic.h" -#include "../server.h" - +#include +#include +#include #include #include @@ -54,8 +53,28 @@ #define LIBRARY_API_NAME "__LIBRARY_API__" #define GLOBALS_API_NAME "__GLOBALS_API__" +typedef uint64_t monotime; + +static monotime getMonotonicUs(void) { + /* clock_gettime() is specified in POSIX.1b (1993). Even so, some systems + * did not support this until much later. CLOCK_MONOTONIC is technically + * optional and may not be supported - but it appears to be universal. + * If this is not supported, provide a system-specific alternate version. */ + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return ((uint64_t)ts.tv_sec) * 1000000 + ts.tv_nsec / 1000; +} + +static inline uint64_t elapsedUs(monotime start_time) { + return getMonotonicUs() - start_time; +} + +static inline uint64_t elapsedMs(monotime start_time) { + return elapsedUs(start_time) / 1000; +} + typedef struct loadCtx { - list *functions; + List *functions; monotime start_time; size_t timeout; } loadCtx; @@ -65,9 +84,9 @@ typedef struct loadCtx { * This execution should be fast and should only register * functions so 500ms should be more than enough. */ static void luaEngineLoadHook(lua_State *lua, lua_Debug *ar) { - UNUSED(ar); + VALKEYMODULE_NOT_USED(ar); loadCtx *load_ctx = luaGetFromRegistry(lua, REGISTRY_LOAD_CTX_NAME); - serverAssert(load_ctx); /* Only supported inside script invocation */ + ValkeyModule_Assert(load_ctx); /* Only supported inside script invocation */ uint64_t duration = elapsedMs(load_ctx->start_time); if (load_ctx->timeout > 0 && duration > load_ctx->timeout) { lua_sethook(lua, luaEngineLoadHook, LUA_MASKLINE, 0); @@ -78,13 +97,13 @@ static void luaEngineLoadHook(lua_State *lua, lua_Debug *ar) { } static void freeCompiledFunc(lua_State *lua, - compiledFunction *compiled_func) { - decrRefCount(compiled_func->name); + ValkeyModuleScriptingEngineCompiledFunction *compiled_func) { + ValkeyModule_FreeString(NULL, compiled_func->name); if (compiled_func->desc) { - decrRefCount(compiled_func->desc); + ValkeyModule_FreeString(NULL, compiled_func->desc); } luaFunctionFreeFunction(lua, compiled_func->function); - zfree(compiled_func); + ValkeyModule_Free(compiled_func); } /* @@ -98,12 +117,12 @@ static void freeCompiledFunc(lua_State *lua, * * Return NULL on compilation error and set the error to the err variable */ -compiledFunction **luaFunctionLibraryCreate(lua_State *lua, - const char *code, - size_t timeout, - size_t *out_num_compiled_functions, - robj **err) { - compiledFunction **compiled_functions = NULL; +ValkeyModuleScriptingEngineCompiledFunction **luaFunctionLibraryCreate(lua_State *lua, + const char *code, + size_t timeout, + size_t *out_num_compiled_functions, + ValkeyModuleString **err) { + ValkeyModuleScriptingEngineCompiledFunction **compiled_functions = NULL; /* set load library globals */ lua_getmetatable(lua, LUA_GLOBALSINDEX); @@ -115,15 +134,14 @@ compiledFunction **luaFunctionLibraryCreate(lua_State *lua, /* compile the code */ if (luaL_loadbuffer(lua, code, strlen(code), "@user_function")) { - sds error = sdscatfmt(sdsempty(), "Error compiling function: %s", lua_tostring(lua, -1)); - *err = createObject(OBJ_STRING, error); + *err = ValkeyModule_CreateStringPrintf(NULL, "Error compiling function: %s", lua_tostring(lua, -1)); lua_pop(lua, 1); /* pops the error */ goto done; } - serverAssert(lua_isfunction(lua, -1)); + ValkeyModule_Assert(lua_isfunction(lua, -1)); loadCtx load_ctx = { - .functions = listCreate(), + .functions = list_create(), .start_time = getMonotonicUs(), .timeout = timeout, }; @@ -134,34 +152,34 @@ compiledFunction **luaFunctionLibraryCreate(lua_State *lua, if (lua_pcall(lua, 0, 0, 0)) { errorInfo err_info = {0}; luaExtractErrorInformation(lua, &err_info); - sds error = sdscatfmt(sdsempty(), "Error registering functions: %s", err_info.msg); - *err = createObject(OBJ_STRING, error); + *err = ValkeyModule_CreateStringPrintf(NULL, "Error registering functions: %s", err_info.msg); lua_pop(lua, 1); /* pops the error */ luaErrorInformationDiscard(&err_info); - listIter *iter = listGetIterator(load_ctx.functions, AL_START_HEAD); - listNode *node = NULL; - while ((node = listNext(iter)) != NULL) { - freeCompiledFunc(lua, listNodeValue(node)); + ListIter *iter = list_get_iter(load_ctx.functions); + void *val = NULL; + while ((val = list_iter_next(iter)) != NULL) { + freeCompiledFunc(lua, val); } - listReleaseIterator(iter); - listRelease(load_ctx.functions); + list_release_iter(iter); + list_destroy(load_ctx.functions); goto done; } compiled_functions = - zcalloc(sizeof(compiledFunction *) * listLength(load_ctx.functions)); - listIter *iter = listGetIterator(load_ctx.functions, AL_START_HEAD); - listNode *node = NULL; + ValkeyModule_Calloc(list_length(load_ctx.functions), + sizeof(ValkeyModuleScriptingEngineCompiledFunction *)); + ListIter *iter = list_get_iter(load_ctx.functions); + void *val = NULL; *out_num_compiled_functions = 0; - while ((node = listNext(iter)) != NULL) { - compiledFunction *func = listNodeValue(node); + while ((val = list_iter_next(iter)) != NULL) { + ValkeyModuleScriptingEngineCompiledFunction *func = val; compiled_functions[*out_num_compiled_functions] = func; (*out_num_compiled_functions)++; } - listReleaseIterator(iter); - listRelease(load_ctx.functions); + list_release_iter(iter); + list_destroy(load_ctx.functions); done: /* restore original globals */ @@ -177,12 +195,12 @@ compiledFunction **luaFunctionLibraryCreate(lua_State *lua, return compiled_functions; } -static void luaRegisterFunctionArgsInitialize(compiledFunction *func, - robj *name, - robj *desc, +static void luaRegisterFunctionArgsInitialize(ValkeyModuleScriptingEngineCompiledFunction *func, + ValkeyModuleString *name, + ValkeyModuleString *desc, luaFunction *script, uint64_t flags) { - *func = (compiledFunction){ + *func = (ValkeyModuleScriptingEngineCompiledFunction){ .name = name, .desc = desc, .function = script, @@ -190,6 +208,20 @@ static void luaRegisterFunctionArgsInitialize(compiledFunction *func, }; } +typedef struct flagStr { + ValkeyModuleScriptingEngineScriptFlag flag; + const char *str; +} flagStr; + +flagStr scripts_flags_def[] = { + {.flag = VMSE_SCRIPT_FLAG_NO_WRITES, .str = "no-writes"}, + {.flag = VMSE_SCRIPT_FLAG_ALLOW_OOM, .str = "allow-oom"}, + {.flag = VMSE_SCRIPT_FLAG_ALLOW_STALE, .str = "allow-stale"}, + {.flag = VMSE_SCRIPT_FLAG_NO_CLUSTER, .str = "no-cluster"}, + {.flag = VMSE_SCRIPT_FLAG_ALLOW_CROSS_SLOT, .str = "allow-cross-slot-keys"}, + {.flag = 0, .str = NULL}, /* flags array end */ +}; + /* Read function flags located on the top of the Lua stack. * On success, return C_OK and set the flags to 'flags' out parameter * Return C_ERR if encounter an unknown flag. */ @@ -212,7 +244,7 @@ static int luaRegisterFunctionReadFlags(lua_State *lua, uint64_t *flags) { const char *flag_str = lua_tostring(lua, -1); int found = 0; - for (scriptFlag *flag = scripts_flags_def; flag->str; ++flag) { + for (flagStr *flag = scripts_flags_def; flag->str; ++flag) { if (!strcasecmp(flag->str, flag_str)) { f_flags |= flag->flag; found = 1; @@ -234,11 +266,23 @@ static int luaRegisterFunctionReadFlags(lua_State *lua, uint64_t *flags) { return ret; } +/* Return a Valkey string of the string value located on stack at the given index. + * Return NULL if the value is not a string. */ +static ValkeyModuleString *luaGetStringObject(lua_State *lua, int index) { + if (!lua_isstring(lua, index)) { + return NULL; + } + + size_t len; + const char *str = lua_tolstring(lua, index, &len); + return ValkeyModule_CreateString(NULL, str, len); +} + static int luaRegisterFunctionReadNamedArgs(lua_State *lua, - compiledFunction *func) { + ValkeyModuleScriptingEngineCompiledFunction *func) { char *err = NULL; - robj *name = NULL; - robj *desc = NULL; + ValkeyModuleString *name = NULL; + ValkeyModuleString *desc = NULL; luaFunction *script = NULL; uint64_t flags = 0; if (!lua_istable(lua, 1)) { @@ -274,7 +318,7 @@ static int luaRegisterFunctionReadNamedArgs(lua_State *lua, } int lua_function_ref = luaL_ref(lua, LUA_REGISTRYINDEX); - script = zmalloc(sizeof(*script)); + script = ValkeyModule_Alloc(sizeof(*script)); script->lua = lua; script->function_ref = lua_function_ref; continue; /* value was already popped, so no need to pop it out. */ @@ -314,21 +358,22 @@ static int luaRegisterFunctionReadNamedArgs(lua_State *lua, return C_OK; error: - if (name) decrRefCount(name); - if (desc) decrRefCount(desc); + if (name) ValkeyModule_FreeString(NULL, name); + if (desc) ValkeyModule_FreeString(NULL, desc); if (script) { lua_unref(lua, script->function_ref); - zfree(script); + ValkeyModule_Free(script); } luaPushError(lua, err); return C_ERR; } static int luaRegisterFunctionReadPositionalArgs(lua_State *lua, - compiledFunction *func) { + ValkeyModuleScriptingEngineCompiledFunction *func) { char *err = NULL; - robj *name = NULL; + ValkeyModuleString *name = NULL; luaFunction *script = NULL; + if (!(name = luaGetStringObject(lua, 1))) { err = "first argument to server.register_function must be a string"; goto error; @@ -341,7 +386,7 @@ static int luaRegisterFunctionReadPositionalArgs(lua_State *lua, int lua_function_ref = luaL_ref(lua, LUA_REGISTRYINDEX); - script = zmalloc(sizeof(*script)); + script = ValkeyModule_Alloc(sizeof(*script)); script->lua = lua; script->function_ref = lua_function_ref; @@ -350,12 +395,13 @@ static int luaRegisterFunctionReadPositionalArgs(lua_State *lua, return C_OK; error: - if (name) decrRefCount(name); + if (name) ValkeyModule_FreeString(NULL, name); luaPushError(lua, err); return C_ERR; } -static int luaRegisterFunctionReadArgs(lua_State *lua, compiledFunction *func) { +static int luaRegisterFunctionReadArgs(lua_State *lua, + ValkeyModuleScriptingEngineCompiledFunction *func) { int argc = lua_gettop(lua); if (argc < 1 || argc > 2) { luaPushError(lua, "wrong number of arguments to server.register_function"); @@ -376,19 +422,19 @@ static int luaFunctionRegisterFunction(lua_State *lua) { return luaError(lua); } - compiledFunction *func = zcalloc(sizeof(*func)); + ValkeyModuleScriptingEngineCompiledFunction *func = ValkeyModule_Calloc(1, sizeof(*func)); if (luaRegisterFunctionReadArgs(lua, func) != C_OK) { - zfree(func); + ValkeyModule_Free(func); return luaError(lua); } - listAddNodeTail(load_ctx->functions, func); + list_add(load_ctx->functions, func); return 0; } -void luaFunctionInitializeLuaState(lua_State *lua) { +void luaFunctionInitializeLuaState(luaEngineCtx *ctx, lua_State *lua) { /* Register the library commands table and fields and store it to registry */ lua_newtable(lua); /* load library globals */ lua_newtable(lua); /* load library `server` table */ @@ -398,7 +444,7 @@ void luaFunctionInitializeLuaState(lua_State *lua) { lua_settable(lua, -3); luaRegisterLogFunction(lua); - luaRegisterVersion(lua); + luaRegisterVersion(ctx, lua); luaSetErrorMetatable(lua); lua_setfield(lua, -2, SERVER_API_NAME); @@ -433,5 +479,5 @@ void luaFunctionInitializeLuaState(lua_State *lua) { void luaFunctionFreeFunction(lua_State *lua, void *function) { luaFunction *script = function; lua_unref(lua, script->function_ref); - zfree(function); + ValkeyModule_Free(function); } diff --git a/src/modules/lua/function_lua.h b/src/modules/lua/function_lua.h new file mode 100644 index 0000000000..791b4a789d --- /dev/null +++ b/src/modules/lua/function_lua.h @@ -0,0 +1,17 @@ +#ifndef _FUNCTION_LUA_H_ +#define _FUNCTION_LUA_H_ + +#include "../../valkeymodule.h" +#include "engine_structs.h" + +void luaFunctionInitializeLuaState(luaEngineCtx *ctx, lua_State *lua); + +ValkeyModuleScriptingEngineCompiledFunction **luaFunctionLibraryCreate(lua_State *lua, + const char *code, + size_t timeout, + size_t *out_num_compiled_functions, + ValkeyModuleString **err); + +void luaFunctionFreeFunction(lua_State *lua, void *function); + +#endif /* _FUNCTION_LUA_H_ */ diff --git a/src/modules/lua/list.c b/src/modules/lua/list.c new file mode 100644 index 0000000000..69e327a709 --- /dev/null +++ b/src/modules/lua/list.c @@ -0,0 +1,77 @@ +/* + * Copyright (c) Valkey Contributors + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +#include "list.h" +#include "../../valkeymodule.h" + +typedef struct ListNode { + void *val; + struct ListNode *next; +} ListNode; + +typedef struct List { + ListNode *head; + int length; +} List; + +typedef struct ListIter { + ListNode *current; +} ListIter; + +List *list_create(void) { + List *list = ValkeyModule_Alloc(sizeof(List)); + list->head = NULL; + list->length = 0; + return list; +} + +void list_destroy(List *list) { + ListNode *current = list->head; + while (current) { + ListNode *next = current->next; + ValkeyModule_Free(current); + current = next; + } + ValkeyModule_Free(list); +} + +int list_length(List *list) { + return list->length; +} + +void list_add(List *list, void *val) { + ListNode *new_node = ValkeyModule_Calloc(1, sizeof(ListNode)); + new_node->val = val; + new_node->next = NULL; + if (!list->head) { + list->head = new_node; + } else { + ListNode *current = list->head; + while (current->next) { + current = current->next; + } + current->next = new_node; + } + list->length++; +} + +ListIter *list_get_iter(List *list) { + ListIter *iter = ValkeyModule_Calloc(1, sizeof(ListIter)); + iter->current = list->head; + return iter; +} + +void *list_iter_next(ListIter *iter) { + if (!iter->current) { + return NULL; + } + void *val = iter->current->val; + iter->current = iter->current->next; + return val; +} + +void list_release_iter(ListIter *iter) { + ValkeyModule_Free(iter); +} diff --git a/src/modules/lua/list.h b/src/modules/lua/list.h new file mode 100644 index 0000000000..230be747ce --- /dev/null +++ b/src/modules/lua/list.h @@ -0,0 +1,16 @@ +#ifndef _LUA_LIST_H_ +#define _LUA_LIST_H_ + +typedef struct List List; +typedef struct ListIter ListIter; + +List *list_create(void); +void list_destroy(List *list); +void list_add(List *list, void *val); +int list_length(List *list); + +ListIter *list_get_iter(List *list); +void *list_iter_next(ListIter *iter); +void list_release_iter(ListIter *iter); + +#endif /* _LUA_LIST_H_ */ diff --git a/src/lua/script_lua.c b/src/modules/lua/script_lua.c similarity index 59% rename from src/lua/script_lua.c rename to src/modules/lua/script_lua.c index 5fafddd40a..b1fea7a641 100644 --- a/src/lua/script_lua.c +++ b/src/modules/lua/script_lua.c @@ -27,21 +27,46 @@ * POSSIBILITY OF SUCH DAMAGE. */ +#include "../../valkeymodule.h" #include "script_lua.h" #include "debug_lua.h" - -#include "../sha1.h" -#include "../rand.h" -#include "../cluster.h" -#include "../monotonic.h" -#include "../resp_parser.h" -#include "../version.h" +#include "engine_structs.h" +#include "../../sha1.h" +#include "../../rand.h" #include +#include #include #include -#include +#include +#include #include +#include +#include +#include + +#define LUA_CMD_OBJCACHE_SIZE 32 +#define LUA_CMD_OBJCACHE_MAX_LEN 64 + +/* Command propagation flags, see propagateNow() function */ +#define PROPAGATE_NONE 0 +#define PROPAGATE_AOF 1 +#define PROPAGATE_REPL 2 + +/* Log levels */ +#define LL_DEBUG 0 +#define LL_VERBOSE 1 +#define LL_NOTICE 2 +#define LL_WARNING 3 + +typedef struct luaFuncCallCtx { + ValkeyModuleCtx *module_ctx; + ValkeyModuleScriptingEngineServerRuntimeCtx *run_ctx; + ValkeyModuleScriptingEngineSubsystemType type; + int replication_flags; + int resp; + int lua_enable_insecure_api; +} luaFuncCallCtx; /* Globals that are added by the Lua libraries */ static char *libraries_allow_list[] = { @@ -136,32 +161,39 @@ static char *deny_list[] = { NULL, }; +static void _serverPanic(const char *file, int line, const char *msg, ...) { + fprintf(stderr, "------------------------------------------------"); + fprintf(stderr, "!!! Software Failure."); + fprintf(stderr, "Guru Meditation: %s #%s:%d", msg, file, line); + abort(); +} + +#define serverPanic(...) _serverPanic(__FILE__, __LINE__, __VA_ARGS__) + +typedef uint64_t monotime; + +monotime getMonotonicUs(void) { + /* clock_gettime() is specified in POSIX.1b (1993). Even so, some systems + * did not support this until much later. CLOCK_MONOTONIC is technically + * optional and may not be supported - but it appears to be universal. + * If this is not supported, provide a system-specific alternate version. */ + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return ((uint64_t)ts.tv_sec) * 1000000 + ts.tv_nsec / 1000; +} + +inline uint64_t elapsedUs(monotime start_time) { + return getMonotonicUs() - start_time; +} + +inline uint64_t elapsedMs(monotime start_time) { + return elapsedUs(start_time) / 1000; +} + static int server_math_random(lua_State *L); static int server_math_randomseed(lua_State *L); -static void redisProtocolToLuaType_Int(void *ctx, long long val, const char *proto, size_t proto_len); -static void -redisProtocolToLuaType_BulkString(void *ctx, const char *str, size_t len, const char *proto, size_t proto_len); -static void redisProtocolToLuaType_NullBulkString(void *ctx, const char *proto, size_t proto_len); -static void redisProtocolToLuaType_NullArray(void *ctx, const char *proto, size_t proto_len); -static void redisProtocolToLuaType_Status(void *ctx, const char *str, size_t len, const char *proto, size_t proto_len); -static void redisProtocolToLuaType_Error(void *ctx, const char *str, size_t len, const char *proto, size_t proto_len); -static void redisProtocolToLuaType_Array(struct ReplyParser *parser, void *ctx, size_t len, const char *proto); -static void redisProtocolToLuaType_Map(struct ReplyParser *parser, void *ctx, size_t len, const char *proto); -static void redisProtocolToLuaType_Set(struct ReplyParser *parser, void *ctx, size_t len, const char *proto); -static void redisProtocolToLuaType_Null(void *ctx, const char *proto, size_t proto_len); -static void redisProtocolToLuaType_Bool(void *ctx, int val, const char *proto, size_t proto_len); -static void redisProtocolToLuaType_Double(void *ctx, double d, const char *proto, size_t proto_len); -static void -redisProtocolToLuaType_BigNumber(void *ctx, const char *str, size_t len, const char *proto, size_t proto_len); -static void redisProtocolToLuaType_VerbatimString(void *ctx, - const char *format, - const char *str, - size_t len, - const char *proto, - size_t proto_len); -static void redisProtocolToLuaType_Attribute(struct ReplyParser *parser, void *ctx, size_t len, const char *proto); - -static void luaReplyToServerReply(client *c, client *script_client, lua_State *lua); + +static void luaReplyToServerReply(ValkeyModuleCtx *ctx, int resp_version, lua_State *lua); /* * Save the give pointer on Lua registry, used to save the Lua context and @@ -189,10 +221,10 @@ void *luaGetFromRegistry(lua_State *lua, const char *name) { return NULL; } /* must be light user data */ - serverAssert(lua_islightuserdata(lua, -1)); + ValkeyModule_Assert(lua_islightuserdata(lua, -1)); void *ptr = (void *)lua_topointer(lua, -1); - serverAssert(ptr); + ValkeyModule_Assert(ptr); /* pops the value */ lua_pop(lua, 1); @@ -200,419 +232,343 @@ void *luaGetFromRegistry(lua_State *lua, const char *name) { return ptr; } -/* --------------------------------------------------------------------------- - * Server reply to Lua type conversion functions. - * ------------------------------------------------------------------------- */ +char *lm_asprintf(char const *fmt, ...) { + va_list args; -/* Take a server reply in the RESP format and convert it into a - * Lua type. Thanks to this function, and the introduction of not connected - * clients, it is trivial to implement the server() lua function. - * - * Basically we take the arguments, execute the command in the context - * of a non connected client, then take the generated reply and convert it - * into a suitable Lua type. With this trick the scripting feature does not - * need the introduction of a full server internals API. The script - * is like a normal client that bypasses all the slow I/O paths. - * - * Note: in this function we do not do any sanity check as the reply is - * generated by the server directly. This allows us to go faster. - * - * Errors are returned as a table with a single 'err' field set to the - * error string. - */ + va_start(args, fmt); + size_t str_len = vsnprintf(NULL, 0, fmt, args) + 1; + va_end(args); -static const ReplyParserCallbacks DefaultLuaTypeParserCallbacks = { - .null_array_callback = redisProtocolToLuaType_NullArray, - .bulk_string_callback = redisProtocolToLuaType_BulkString, - .null_bulk_string_callback = redisProtocolToLuaType_NullBulkString, - .error_callback = redisProtocolToLuaType_Error, - .simple_str_callback = redisProtocolToLuaType_Status, - .long_callback = redisProtocolToLuaType_Int, - .array_callback = redisProtocolToLuaType_Array, - .set_callback = redisProtocolToLuaType_Set, - .map_callback = redisProtocolToLuaType_Map, - .bool_callback = redisProtocolToLuaType_Bool, - .double_callback = redisProtocolToLuaType_Double, - .null_callback = redisProtocolToLuaType_Null, - .big_number_callback = redisProtocolToLuaType_BigNumber, - .verbatim_string_callback = redisProtocolToLuaType_VerbatimString, - .attribute_callback = redisProtocolToLuaType_Attribute, - .error = NULL, -}; + char *str = ValkeyModule_Alloc(str_len); -static void redisProtocolToLuaType(lua_State *lua, char *reply) { - ReplyParser parser = {.curr_location = reply, .callbacks = DefaultLuaTypeParserCallbacks}; + va_start(args, fmt); + vsnprintf(str, str_len, fmt, args); + va_end(args); - parseReply(&parser, lua); + return str; } -static void redisProtocolToLuaType_Int(void *ctx, long long val, const char *proto, size_t proto_len) { - UNUSED(proto); - UNUSED(proto_len); - if (!ctx) { - return; - } - - lua_State *lua = ctx; - if (!lua_checkstack(lua, 1)) { - /* Increase the Lua stack if needed, to make sure there is enough room - * to push elements to the stack. On failure, exit with panic. */ - serverPanic("lua stack limit reach when parsing server.call reply"); - } - lua_pushnumber(lua, (lua_Number)val); +char *lm_strcpy(const char *str) { + size_t len = strlen(str); + char *res = ValkeyModule_Alloc(len + 1); + memcpy(res, str, len + 1); + return res; } -static void redisProtocolToLuaType_NullBulkString(void *ctx, const char *proto, size_t proto_len) { - UNUSED(proto); - UNUSED(proto_len); - if (!ctx) { - return; - } - - lua_State *lua = ctx; - if (!lua_checkstack(lua, 1)) { - /* Increase the Lua stack if needed, to make sure there is enough room - * to push elements to the stack. On failure, exit with panic. */ - serverPanic("lua stack limit reach when parsing server.call reply"); - } - lua_pushboolean(lua, 0); -} +char *lm_strtrim(char *s, const char *cset) { + char *end, *sp, *ep; + size_t len; -static void redisProtocolToLuaType_NullArray(void *ctx, const char *proto, size_t proto_len) { - UNUSED(proto); - UNUSED(proto_len); - if (!ctx) { - return; - } - lua_State *lua = ctx; - if (!lua_checkstack(lua, 1)) { - /* Increase the Lua stack if needed, to make sure there is enough room - * to push elements to the stack. On failure, exit with panic. */ - serverPanic("lua stack limit reach when parsing server.call reply"); - } - lua_pushboolean(lua, 0); + sp = s; + ep = end = s + strlen(s) - 1; + while (sp <= end && strchr(cset, *sp)) sp++; + while (ep > sp && strchr(cset, *ep)) ep--; + len = (ep - sp) + 1; + if (s != sp) memmove(s, sp, len); + s[len] = '\0'; + return s; } +/* This function is used in order to push an error on the Lua stack in the + * format used by server.pcall to return errors, which is a lua table + * with an "err" field set to the error string including the error code. + * Note that this table is never a valid reply by proper commands, + * since the returned tables are otherwise always indexed by integers, never by strings. + * + * The function takes ownership on the given err_buffer. */ +static void luaPushErrorBuff(lua_State *lua, const char *err_buffer) { + char *msg; -static void -redisProtocolToLuaType_BulkString(void *ctx, const char *str, size_t len, const char *proto, size_t proto_len) { - UNUSED(proto); - UNUSED(proto_len); - if (!ctx) { - return; - } - - lua_State *lua = ctx; - if (!lua_checkstack(lua, 1)) { - /* Increase the Lua stack if needed, to make sure there is enough room - * to push elements to the stack. On failure, exit with panic. */ - serverPanic("lua stack limit reach when parsing server.call reply"); + /* If debugging is active and in step mode, log errors resulting from + * server commands. */ + if (ldbIsEnabled()) { + char *msg = lm_asprintf(" %s", err_buffer); + ldbLogCString(msg); + ValkeyModule_Free(msg); } - lua_pushlstring(lua, str, len); -} -static void redisProtocolToLuaType_Status(void *ctx, const char *str, size_t len, const char *proto, size_t proto_len) { - UNUSED(proto); - UNUSED(proto_len); - if (!ctx) { - return; + char *final_msg = NULL; + /* There are two possible formats for the received `error` string: + * 1) "-CODE msg": in this case we remove the leading '-' since we don't store it as part of the lua error format. + * 2) "msg": in this case we prepend a generic 'ERR' code since all error statuses need some error code. + * We support format (1) so this function can reuse the error messages used in other places. + * We support format (2) so it'll be easy to pass descriptive errors to this function without worrying about format. + */ + if (err_buffer[0] == '-') { + /* derive error code from the message */ + char *err_msg = strstr(err_buffer, " "); + if (!err_msg) { + msg = lm_strcpy(err_buffer + 1); + final_msg = lm_asprintf("ERR %s", msg); + } else { + *err_msg = '\0'; + msg = lm_strcpy(err_msg + 1); + msg = lm_strtrim(msg, "\r\n"); + final_msg = lm_asprintf("%s %s", err_buffer + 1, msg); + } + } else { + msg = lm_strcpy(err_buffer); + msg = lm_strtrim(msg, "\r\n"); + final_msg = lm_asprintf("%s", msg); } + /* Trim newline at end of string. If we reuse the ready-made error objects (case 1 above) then we might + * have a newline that needs to be trimmed. In any case the lua server error table shouldn't end with a newline. */ - lua_State *lua = ctx; - if (!lua_checkstack(lua, 3)) { - /* Increase the Lua stack if needed, to make sure there is enough room - * to push elements to the stack. On failure, exit with panic. */ - serverPanic("lua stack limit reach when parsing server.call reply"); - } lua_newtable(lua); - lua_pushstring(lua, "ok"); - lua_pushlstring(lua, str, len); + lua_pushstring(lua, "err"); + lua_pushstring(lua, final_msg); lua_settable(lua, -3); + + ValkeyModule_Free(msg); + ValkeyModule_Free(final_msg); } -static void redisProtocolToLuaType_Error(void *ctx, const char *str, size_t len, const char *proto, size_t proto_len) { - UNUSED(proto); - UNUSED(proto_len); - if (!ctx) { - return; - } +void luaPushError(lua_State *lua, const char *error) { + luaPushErrorBuff(lua, error); +} - lua_State *lua = ctx; - if (!lua_checkstack(lua, 3)) { - /* Increase the Lua stack if needed, to make sure there is enough room - * to push elements to the stack. On failure, exit with panic. */ - serverPanic("lua stack limit reach when parsing server.call reply"); - } - sds err_msg = sdscatlen(sdsnew("-"), str, len); - luaPushErrorBuff(lua, err_msg); - /* push a field indicate to ignore updating the stats on this error - * because it was already updated when executing the command. */ - lua_pushstring(lua, "ignore_error_stats_update"); - lua_pushboolean(lua, 1); - lua_settable(lua, -3); +/* In case the error set into the Lua stack by luaPushError() was generated + * by the non-error-trapping version of server.pcall(), which is server.call(), + * this function will raise the Lua error so that the execution of the + * script will be halted. */ +int luaError(lua_State *lua) { + return lua_error(lua); } -static void redisProtocolToLuaType_Map(struct ReplyParser *parser, void *ctx, size_t len, const char *proto) { - UNUSED(proto); - lua_State *lua = ctx; - if (lua) { - if (!lua_checkstack(lua, 3)) { +/* --------------------------------------------------------------------------- + * Server reply to Lua type conversion functions. + * ------------------------------------------------------------------------- */ + +static void callReplyToLuaType(lua_State *lua, ValkeyModuleCallReply *reply, int resp) { + int type = ValkeyModule_CallReplyType(reply); + switch (type) { + case VALKEYMODULE_REPLY_STRING: { + if (!lua_checkstack(lua, 1)) { /* Increase the Lua stack if needed, to make sure there is enough room * to push elements to the stack. On failure, exit with panic. */ serverPanic("lua stack limit reach when parsing server.call reply"); } - lua_newtable(lua); - lua_pushstring(lua, "map"); - lua_createtable(lua, 0, len); - } - for (size_t j = 0; j < len; j++) { - parseReply(parser, lua); - parseReply(parser, lua); - if (lua) lua_settable(lua, -3); + size_t len = 0; + const char *str = ValkeyModule_CallReplyStringPtr(reply, &len); + lua_pushlstring(lua, str, len); + break; } - if (lua) lua_settable(lua, -3); -} - -static void redisProtocolToLuaType_Set(struct ReplyParser *parser, void *ctx, size_t len, const char *proto) { - UNUSED(proto); - - lua_State *lua = ctx; - if (lua) { + case VALKEYMODULE_REPLY_SIMPLE_STRING: { if (!lua_checkstack(lua, 3)) { /* Increase the Lua stack if needed, to make sure there is enough room * to push elements to the stack. On failure, exit with panic. */ serverPanic("lua stack limit reach when parsing server.call reply"); } + size_t len = 0; + const char *str = ValkeyModule_CallReplyStringPtr(reply, &len); lua_newtable(lua); - lua_pushstring(lua, "set"); - lua_createtable(lua, 0, len); - } - for (size_t j = 0; j < len; j++) { - parseReply(parser, lua); - if (lua) { - if (!lua_checkstack(lua, 1)) { - /* Increase the Lua stack if needed, to make sure there is enough room - * to push elements to the stack. On failure, exit with panic. - * Notice that here we need to check the stack again because the recursive - * call to redisProtocolToLuaType might have use the room allocated in the stack*/ - serverPanic("lua stack limit reach when parsing server.call reply"); - } - lua_pushboolean(lua, 1); - lua_settable(lua, -3); + lua_pushstring(lua, "ok"); + lua_pushlstring(lua, str, len); + lua_settable(lua, -3); + break; + } + case VALKEYMODULE_REPLY_INTEGER: { + if (!lua_checkstack(lua, 1)) { + /* Increase the Lua stack if needed, to make sure there is enough room + * to push elements to the stack. On failure, exit with panic. */ + serverPanic("lua stack limit reach when parsing server.call reply"); } + long long val = ValkeyModule_CallReplyInteger(reply); + lua_pushnumber(lua, (lua_Number)val); + break; } - if (lua) lua_settable(lua, -3); -} - -static void redisProtocolToLuaType_Array(struct ReplyParser *parser, void *ctx, size_t len, const char *proto) { - UNUSED(proto); - - lua_State *lua = ctx; - if (lua) { + case VALKEYMODULE_REPLY_ARRAY: { if (!lua_checkstack(lua, 2)) { /* Increase the Lua stack if needed, to make sure there is enough room * to push elements to the stack. On failure, exit with panic. */ serverPanic("lua stack limit reach when parsing server.call reply"); } - lua_createtable(lua, len, 0); - } - for (size_t j = 0; j < len; j++) { - if (lua) lua_pushnumber(lua, j + 1); - parseReply(parser, lua); - if (lua) lua_settable(lua, -3); - } -} + size_t items = ValkeyModule_CallReplyLength(reply); + lua_createtable(lua, items, 0); -static void redisProtocolToLuaType_Attribute(struct ReplyParser *parser, void *ctx, size_t len, const char *proto) { - UNUSED(proto); + for (size_t i = 0; i < items; i++) { + ValkeyModuleCallReply *val = ValkeyModule_CallReplyArrayElement(reply, i); - /* Parse the attribute reply. - * Currently, we do not expose the attribute to the Lua script so - * we just need to continue parsing and ignore it (the NULL ensures that the - * reply will be ignored). */ - for (size_t j = 0; j < len; j++) { - parseReply(parser, NULL); - parseReply(parser, NULL); + lua_pushnumber(lua, i + 1); + callReplyToLuaType(lua, val, resp); + lua_settable(lua, -3); + } + break; } + case VALKEYMODULE_REPLY_NULL: + case VALKEYMODULE_REPLY_ARRAY_NULL: + if (!lua_checkstack(lua, 1)) { + /* Increase the Lua stack if needed, to make sure there is enough room + * to push elements to the stack. On failure, exit with panic. */ + serverPanic("lua stack limit reach when parsing server.call reply"); + } + if (resp == 2) { + lua_pushboolean(lua, 0); + } else { + lua_pushnil(lua); + } + break; + case VALKEYMODULE_REPLY_MAP: { + if (!lua_checkstack(lua, 3)) { + /* Increase the Lua stack if needed, to make sure there is enough room + * to push elements to the stack. On failure, exit with panic. */ + serverPanic("lua stack limit reach when parsing server.call reply"); + } - /* Parse the reply itself. */ - parseReply(parser, ctx); -} + size_t items = ValkeyModule_CallReplyLength(reply); + lua_newtable(lua); + lua_pushstring(lua, "map"); + lua_createtable(lua, 0, items); -static void redisProtocolToLuaType_VerbatimString(void *ctx, - const char *format, - const char *str, - size_t len, - const char *proto, - size_t proto_len) { - UNUSED(proto); - UNUSED(proto_len); - if (!ctx) { - return; - } + for (size_t i = 0; i < items; i++) { + ValkeyModuleCallReply *key = NULL; + ValkeyModuleCallReply *val = NULL; + ValkeyModule_CallReplyMapElement(reply, i, &key, &val); - lua_State *lua = ctx; - if (!lua_checkstack(lua, 5)) { - /* Increase the Lua stack if needed, to make sure there is enough room - * to push elements to the stack. On failure, exit with panic. */ - serverPanic("lua stack limit reach when parsing server.call reply"); + callReplyToLuaType(lua, key, resp); + callReplyToLuaType(lua, val, resp); + lua_settable(lua, -3); + } + lua_settable(lua, -3); + break; } - lua_newtable(lua); - lua_pushstring(lua, "verbatim_string"); - lua_newtable(lua); - lua_pushstring(lua, "string"); - lua_pushlstring(lua, str, len); - lua_settable(lua, -3); - lua_pushstring(lua, "format"); - lua_pushlstring(lua, format, 3); - lua_settable(lua, -3); - lua_settable(lua, -3); -} + case VALKEYMODULE_REPLY_SET: { + if (!lua_checkstack(lua, 3)) { + /* Increase the Lua stack if needed, to make sure there is enough room + * to push elements to the stack. On failure, exit with panic. */ + serverPanic("lua stack limit reach when parsing server.call reply"); + } -static void -redisProtocolToLuaType_BigNumber(void *ctx, const char *str, size_t len, const char *proto, size_t proto_len) { - UNUSED(proto); - UNUSED(proto_len); - if (!ctx) { - return; - } + size_t items = ValkeyModule_CallReplyLength(reply); + lua_newtable(lua); + lua_pushstring(lua, "set"); + lua_createtable(lua, 0, items); - lua_State *lua = ctx; - if (!lua_checkstack(lua, 3)) { - /* Increase the Lua stack if needed, to make sure there is enough room - * to push elements to the stack. On failure, exit with panic. */ - serverPanic("lua stack limit reach when parsing server.call reply"); - } - lua_newtable(lua); - lua_pushstring(lua, "big_number"); - lua_pushlstring(lua, str, len); - lua_settable(lua, -3); -} + for (size_t i = 0; i < items; i++) { + ValkeyModuleCallReply *val = ValkeyModule_CallReplySetElement(reply, i); -static void redisProtocolToLuaType_Null(void *ctx, const char *proto, size_t proto_len) { - UNUSED(proto); - UNUSED(proto_len); - if (!ctx) { - return; + callReplyToLuaType(lua, val, resp); + lua_pushboolean(lua, 1); + lua_settable(lua, -3); + } + lua_settable(lua, -3); + break; } - - lua_State *lua = ctx; - if (!lua_checkstack(lua, 1)) { - /* Increase the Lua stack if needed, to make sure there is enough room - * to push elements to the stack. On failure, exit with panic. */ - serverPanic("lua stack limit reach when parsing server.call reply"); + case VALKEYMODULE_REPLY_BOOL: { + if (!lua_checkstack(lua, 1)) { + /* Increase the Lua stack if needed, to make sure there is enough room + * to push elements to the stack. On failure, exit with panic. */ + serverPanic("lua stack limit reach when parsing server.call reply"); + } + int b = ValkeyModule_CallReplyBool(reply); + lua_pushboolean(lua, b); + break; } - lua_pushnil(lua); -} - -static void redisProtocolToLuaType_Bool(void *ctx, int val, const char *proto, size_t proto_len) { - UNUSED(proto); - UNUSED(proto_len); - if (!ctx) { - return; + case VALKEYMODULE_REPLY_DOUBLE: { + if (!lua_checkstack(lua, 3)) { + /* Increase the Lua stack if needed, to make sure there is enough room + * to push elements to the stack. On failure, exit with panic. */ + serverPanic("lua stack limit reach when parsing server.call reply"); + } + double d = ValkeyModule_CallReplyDouble(reply); + lua_newtable(lua); + lua_pushstring(lua, "double"); + lua_pushnumber(lua, d); + lua_settable(lua, -3); + break; } - lua_State *lua = ctx; - if (!lua_checkstack(lua, 1)) { - /* Increase the Lua stack if needed, to make sure there is enough room - * to push elements to the stack. On failure, exit with panic. */ - serverPanic("lua stack limit reach when parsing server.call reply"); + case VALKEYMODULE_REPLY_BIG_NUMBER: { + if (!lua_checkstack(lua, 3)) { + /* Increase the Lua stack if needed, to make sure there is enough room + * to push elements to the stack. On failure, exit with panic. */ + serverPanic("lua stack limit reach when parsing server.call reply"); + } + size_t len = 0; + const char *str = ValkeyModule_CallReplyBigNumber(reply, &len); + lua_newtable(lua); + lua_pushstring(lua, "big_number"); + lua_pushlstring(lua, str, len); + lua_settable(lua, -3); + break; } - lua_pushboolean(lua, val); -} - -static void redisProtocolToLuaType_Double(void *ctx, double d, const char *proto, size_t proto_len) { - UNUSED(proto); - UNUSED(proto_len); - if (!ctx) { - return; + case VALKEYMODULE_REPLY_VERBATIM_STRING: { + if (!lua_checkstack(lua, 5)) { + /* Increase the Lua stack if needed, to make sure there is enough room + * to push elements to the stack. On failure, exit with panic. */ + serverPanic("lua stack limit reach when parsing server.call reply"); + } + size_t len = 0; + const char *format = NULL; + const char *str = ValkeyModule_CallReplyVerbatim(reply, &len, &format); + lua_newtable(lua); + lua_pushstring(lua, "verbatim_string"); + lua_newtable(lua); + lua_pushstring(lua, "string"); + lua_pushlstring(lua, str, len); + lua_settable(lua, -3); + lua_pushstring(lua, "format"); + lua_pushlstring(lua, format, 3); + lua_settable(lua, -3); + lua_settable(lua, -3); + break; } - - lua_State *lua = ctx; - if (!lua_checkstack(lua, 3)) { - /* Increase the Lua stack if needed, to make sure there is enough room - * to push elements to the stack. On failure, exit with panic. */ - serverPanic("lua stack limit reach when parsing server.call reply"); + case VALKEYMODULE_REPLY_ERROR: { + if (!lua_checkstack(lua, 3)) { + /* Increase the Lua stack if needed, to make sure there is enough room + * to push elements to the stack. On failure, exit with panic. */ + serverPanic("lua stack limit reach when parsing server.call reply"); + } + const char *err = ValkeyModule_CallReplyStringPtr(reply, NULL); + luaPushErrorBuff(lua, err); + /* push a field indicate to ignore updating the stats on this error + * because it was already updated when executing the command. */ + lua_pushstring(lua, "ignore_error_stats_update"); + lua_pushboolean(lua, 1); + lua_settable(lua, -3); + break; + } + case VALKEYMODULE_REPLY_ATTRIBUTE: { + /* Currently, we do not expose the attribute to the Lua script. */ + break; + } + case VALKEYMODULE_REPLY_PROMISE: + case VALKEYMODULE_REPLY_UNKNOWN: + default: + ValkeyModule_Assert(0); } - lua_newtable(lua); - lua_pushstring(lua, "double"); - lua_pushnumber(lua, d); - lua_settable(lua, -3); } -/* This function is used in order to push an error on the Lua stack in the - * format used by server.pcall to return errors, which is a lua table - * with an "err" field set to the error string including the error code. - * Note that this table is never a valid reply by proper commands, - * since the returned tables are otherwise always indexed by integers, never by strings. - * - * The function takes ownership on the given err_buffer. */ -void luaPushErrorBuff(lua_State *lua, sds err_buffer) { - sds msg; - sds error_code; +/* --------------------------------------------------------------------------- + * Lua reply to server reply conversion functions. + * ------------------------------------------------------------------------- */ - /* If debugging is active and in step mode, log errors resulting from - * server commands. */ - if (ldbIsEnabled()) { - ldbLog(sdscatprintf(sdsempty(), " %s", err_buffer)); - } +char *strmapchars(char *s, const char *from, const char *to, size_t setlen) { + size_t j, i, l = strlen(s); - /* There are two possible formats for the received `error` string: - * 1) "-CODE msg": in this case we remove the leading '-' since we don't store it as part of the lua error format. - * 2) "msg": in this case we prepend a generic 'ERR' code since all error statuses need some error code. - * We support format (1) so this function can reuse the error messages used in other places. - * We support format (2) so it'll be easy to pass descriptive errors to this function without worrying about format. - */ - if (err_buffer[0] == '-') { - /* derive error code from the message */ - char *err_msg = strstr(err_buffer, " "); - if (!err_msg) { - msg = sdsnew(err_buffer + 1); - error_code = sdsnew("ERR"); - } else { - *err_msg = '\0'; - msg = sdsnew(err_msg + 1); - error_code = sdsnew(err_buffer + 1); + for (j = 0; j < l; j++) { + for (i = 0; i < setlen; i++) { + if (s[j] == from[i]) { + s[j] = to[i]; + break; + } } - sdsfree(err_buffer); - } else { - msg = err_buffer; - error_code = sdsnew("ERR"); } - /* Trim newline at end of string. If we reuse the ready-made error objects (case 1 above) then we might - * have a newline that needs to be trimmed. In any case the lua server error table shouldn't end with a newline. */ - msg = sdstrim(msg, "\r\n"); - sds final_msg = sdscatfmt(error_code, " %s", msg); - - lua_newtable(lua); - lua_pushstring(lua, "err"); - lua_pushstring(lua, final_msg); - lua_settable(lua, -3); - - sdsfree(msg); - sdsfree(final_msg); -} - -void luaPushError(lua_State *lua, const char *error) { - luaPushErrorBuff(lua, sdsnew(error)); + return s; } -/* In case the error set into the Lua stack by luaPushError() was generated - * by the non-error-trapping version of server.pcall(), which is server.call(), - * this function will raise the Lua error so that the execution of the - * script will be halted. */ -int luaError(lua_State *lua) { - return lua_error(lua); +char *copy_string_from_lua_stack(lua_State *lua) { + const char *str = lua_tostring(lua, -1); + size_t len = lua_strlen(lua, -1); + char *res = ValkeyModule_Alloc(len + 1); + strncpy(res, str, len); + res[len] = 0; + return res; } - -/* --------------------------------------------------------------------------- - * Lua reply to server reply conversion functions. - * ------------------------------------------------------------------------- */ - /* Reply to client 'c' converting the top element in the Lua stack to a * server reply. As a side effect the element is consumed from the stack. */ -static void luaReplyToServerReply(client *c, client *script_client, lua_State *lua) { +static void luaReplyToServerReply(ValkeyModuleCtx *ctx, int resp_version, lua_State *lua) { int t = lua_type(lua, -1); if (!lua_checkstack(lua, 4)) { @@ -620,20 +576,28 @@ static void luaReplyToServerReply(client *c, client *script_client, lua_State *l * to push 4 elements to the stack. On failure, return error. * Notice that we need, in the worst case, 4 elements because returning a map might * require push 4 elements to the Lua stack.*/ - addReplyError(c, "reached lua stack limit"); + ValkeyModule_ReplyWithError(ctx, "ERR reached lua stack limit"); lua_pop(lua, 1); /* pop the element from the stack */ return; } switch (t) { - case LUA_TSTRING: addReplyBulkCBuffer(c, (char *)lua_tostring(lua, -1), lua_strlen(lua, -1)); break; + case LUA_TSTRING: + ValkeyModule_ReplyWithStringBuffer(ctx, lua_tostring(lua, -1), lua_strlen(lua, -1)); + break; case LUA_TBOOLEAN: - if (script_client->resp == 2) - addReply(c, lua_toboolean(lua, -1) ? shared.cone : shared.null[c->resp]); - else - addReplyBool(c, lua_toboolean(lua, -1)); + if (resp_version == 2) { + int b = lua_toboolean(lua, -1); + if (b) { + ValkeyModule_ReplyWithLongLong(ctx, 1); + } else { + ValkeyModule_ReplyWithNull(ctx); + } + } else { + ValkeyModule_ReplyWithBool(ctx, lua_toboolean(lua, -1)); + } break; - case LUA_TNUMBER: addReplyLongLong(c, (long long)lua_tonumber(lua, -1)); break; + case LUA_TNUMBER: ValkeyModule_ReplyWithLongLong(ctx, (long long)lua_tonumber(lua, -1)); break; case LUA_TTABLE: /* We need to check if it is an array, an error, or a status reply. * Error are returned as a single element table with 'err' field. @@ -650,9 +614,7 @@ static void luaReplyToServerReply(client *c, client *script_client, lua_State *l 1); /* pop the error message, we will use luaExtractErrorInformation to get error information */ errorInfo err_info = {0}; luaExtractErrorInformation(lua, &err_info); - addReplyErrorFormatEx( - c, ERR_REPLY_FLAG_CUSTOM | (err_info.ignore_err_stats_update ? ERR_REPLY_FLAG_NO_STATS_UPDATE : 0), - "-%s", err_info.msg); + ValkeyModule_ReplyWithCustomErrorFormat(ctx, !err_info.ignore_err_stats_update, "%s", err_info.msg); luaErrorInformationDiscard(&err_info); lua_pop(lua, 1); /* pop the result table */ return; @@ -664,10 +626,10 @@ static void luaReplyToServerReply(client *c, client *script_client, lua_State *l lua_rawget(lua, -2); t = lua_type(lua, -1); if (t == LUA_TSTRING) { - sds ok = sdsnew(lua_tostring(lua, -1)); - sdsmapchars(ok, "\r\n", " ", 2); - addReplyStatusLength(c, ok, sdslen(ok)); - sdsfree(ok); + char *ok = copy_string_from_lua_stack(lua); + strmapchars(ok, "\r\n", " ", 2); + ValkeyModule_ReplyWithSimpleString(ctx, ok); + ValkeyModule_Free(ok); lua_pop(lua, 2); return; } @@ -678,7 +640,7 @@ static void luaReplyToServerReply(client *c, client *script_client, lua_State *l lua_rawget(lua, -2); t = lua_type(lua, -1); if (t == LUA_TNUMBER) { - addReplyDouble(c, lua_tonumber(lua, -1)); + ValkeyModule_ReplyWithDouble(ctx, lua_tonumber(lua, -1)); lua_pop(lua, 2); return; } @@ -689,10 +651,10 @@ static void luaReplyToServerReply(client *c, client *script_client, lua_State *l lua_rawget(lua, -2); t = lua_type(lua, -1); if (t == LUA_TSTRING) { - sds big_num = sdsnewlen(lua_tostring(lua, -1), lua_strlen(lua, -1)); - sdsmapchars(big_num, "\r\n", " ", 2); - addReplyBigNum(c, big_num, sdslen(big_num)); - sdsfree(big_num); + char *big_num = copy_string_from_lua_stack(lua); + strmapchars(big_num, "\r\n", " ", 2); + ValkeyModule_ReplyWithBigNumber(ctx, big_num, strlen(big_num)); + ValkeyModule_Free(big_num); lua_pop(lua, 2); return; } @@ -714,7 +676,7 @@ static void luaReplyToServerReply(client *c, client *script_client, lua_State *l if (t == LUA_TSTRING) { size_t len; char *str = (char *)lua_tolstring(lua, -1, &len); - addReplyVerbatim(c, str, len, format); + ValkeyModule_ReplyWithVerbatimStringType(ctx, str, len, format); lua_pop(lua, 4); return; } @@ -730,18 +692,18 @@ static void luaReplyToServerReply(client *c, client *script_client, lua_State *l t = lua_type(lua, -1); if (t == LUA_TTABLE) { int maplen = 0; - void *replylen = addReplyDeferredLen(c); + ValkeyModule_ReplyWithMap(ctx, VALKEYMODULE_POSTPONED_LEN); /* we took care of the stack size on function start */ lua_pushnil(lua); /* Use nil to start iteration. */ while (lua_next(lua, -2)) { /* Stack now: table, key, value */ - lua_pushvalue(lua, -2); /* Dup key before consuming. */ - luaReplyToServerReply(c, script_client, lua); /* Return key. */ - luaReplyToServerReply(c, script_client, lua); /* Return value. */ + lua_pushvalue(lua, -2); /* Dup key before consuming. */ + luaReplyToServerReply(ctx, resp_version, lua); /* Return key. */ + luaReplyToServerReply(ctx, resp_version, lua); /* Return value. */ /* Stack now: table, key. */ maplen++; } - setDeferredMapLen(c, replylen, maplen); + ValkeyModule_ReplySetMapLength(ctx, maplen); lua_pop(lua, 2); return; } @@ -753,25 +715,25 @@ static void luaReplyToServerReply(client *c, client *script_client, lua_State *l t = lua_type(lua, -1); if (t == LUA_TTABLE) { int setlen = 0; - void *replylen = addReplyDeferredLen(c); + ValkeyModule_ReplyWithSet(ctx, VALKEYMODULE_POSTPONED_LEN); /* we took care of the stack size on function start */ lua_pushnil(lua); /* Use nil to start iteration. */ while (lua_next(lua, -2)) { /* Stack now: table, key, true */ - lua_pop(lua, 1); /* Discard the boolean value. */ - lua_pushvalue(lua, -1); /* Dup key before consuming. */ - luaReplyToServerReply(c, script_client, lua); /* Return key. */ + lua_pop(lua, 1); /* Discard the boolean value. */ + lua_pushvalue(lua, -1); /* Dup key before consuming. */ + luaReplyToServerReply(ctx, resp_version, lua); /* Return key. */ /* Stack now: table, key. */ setlen++; } - setDeferredSetLen(c, replylen, setlen); + ValkeyModule_ReplySetSetLength(ctx, setlen); lua_pop(lua, 2); return; } lua_pop(lua, 1); /* Discard field name pushed before. */ /* Handle the array reply. */ - void *replylen = addReplyDeferredLen(c); + ValkeyModule_ReplyWithArray(ctx, VALKEYMODULE_POSTPONED_LEN); int j = 1, mbulklen = 0; while (1) { /* we took care of the stack size on function start */ @@ -782,12 +744,12 @@ static void luaReplyToServerReply(client *c, client *script_client, lua_State *l lua_pop(lua, 1); break; } - luaReplyToServerReply(c, script_client, lua); + luaReplyToServerReply(ctx, resp_version, lua); mbulklen++; } - setDeferredArrayLen(c, replylen, mbulklen); + ValkeyModule_ReplySetArrayLength(ctx, mbulklen); break; - default: addReplyNull(c); + default: ValkeyModule_ReplyWithNull(ctx); } lua_pop(lua, 1); } @@ -795,17 +757,144 @@ static void luaReplyToServerReply(client *c, client *script_client, lua_State *l /* --------------------------------------------------------------------------- * Lua server.* functions implementations. * ------------------------------------------------------------------------- */ -void freeLuaServerArgv(robj **argv, int argc, int argv_len); +void freeLuaServerArgv(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int argc); + +/* Return the number of digits of 'v' when converted to string in radix 10. + * See ll2string() for more information. */ +static uint32_t digits10(uint64_t v) { + if (v < 10) return 1; + if (v < 100) return 2; + if (v < 1000) return 3; + if (v < 1000000000000UL) { + if (v < 100000000UL) { + if (v < 1000000) { + if (v < 10000) return 4; + return 5 + (v >= 100000); + } + return 7 + (v >= 10000000UL); + } + if (v < 10000000000UL) { + return 9 + (v >= 1000000000UL); + } + return 11 + (v >= 100000000000UL); + } + return 12 + digits10(v / 1000000000000UL); +} + +/* Convert a unsigned long long into a string. Returns the number of + * characters needed to represent the number. + * If the buffer is not big enough to store the string, 0 is returned. + * + * Based on the following article (that apparently does not provide a + * novel approach but only publicizes an already used technique): + * + * https://www.facebook.com/notes/facebook-engineering/three-optimization-tips-for-c/10151361643253920 */ +static int ull2string(char *dst, size_t dstlen, unsigned long long value) { + static const char digits[201] = "0001020304050607080910111213141516171819" + "2021222324252627282930313233343536373839" + "4041424344454647484950515253545556575859" + "6061626364656667686970717273747576777879" + "8081828384858687888990919293949596979899"; + + /* Check length. */ + uint32_t length = digits10(value); + if (length >= dstlen) goto err; + ; + + /* Null term. */ + uint32_t next = length - 1; + dst[next + 1] = '\0'; + while (value >= 100) { + int const i = (value % 100) * 2; + value /= 100; + dst[next] = digits[i + 1]; + dst[next - 1] = digits[i]; + next -= 2; + } + + /* Handle last 1-2 digits. */ + if (value < 10) { + dst[next] = '0' + (uint32_t)value; + } else { + int i = (uint32_t)value * 2; + dst[next] = digits[i + 1]; + dst[next - 1] = digits[i]; + } + return length; +err: + /* force add Null termination */ + if (dstlen > 0) dst[0] = '\0'; + return 0; +} -/* Cached argv array across calls. */ -static robj **lua_argv = NULL; -static int lua_argv_size = 0; +/* Convert a long long into a string. Returns the number of + * characters needed to represent the number. + * If the buffer is not big enough to store the string, 0 is returned. */ +static int ll2string(char *dst, size_t dstlen, long long svalue) { + unsigned long long value; + int negative = 0; + + /* The ull2string function with 64bit unsigned integers for simplicity, so + * we convert the number here and remember if it is negative. */ + if (svalue < 0) { + if (svalue != LLONG_MIN) { + value = -svalue; + } else { + value = ((unsigned long long)LLONG_MAX) + 1; + } + if (dstlen < 2) goto err; + negative = 1; + dst[0] = '-'; + dst++; + dstlen--; + } else { + value = svalue; + } -/* Cache of recently used small arguments to avoid malloc calls. */ -static robj *lua_args_cached_objects[LUA_CMD_OBJCACHE_SIZE]; -static size_t lua_args_cached_objects_len[LUA_CMD_OBJCACHE_SIZE]; + /* Converts the unsigned long long value to string*/ + int length = ull2string(dst, dstlen, value); + if (length == 0) return 0; + return length + negative; -static robj **luaArgsToServerArgv(lua_State *lua, int *argc, int *argv_len) { +err: + /* force add Null termination */ + if (dstlen > 0) dst[0] = '\0'; + return 0; +} + +/* Returns 1 if the double value can safely be represented in long long without + * precision loss, in which case the corresponding long long is stored in the out variable. */ +static int double2ll(double d, long long *out) { +#if (__DBL_MANT_DIG__ >= 52) && (__DBL_MANT_DIG__ <= 63) && (LLONG_MAX == 0x7fffffffffffffffLL) + /* Check if the float is in a safe range to be casted into a + * long long. We are assuming that long long is 64 bit here. + * Also we are assuming that there are no implementations around where + * double has precision < 52 bit. + * + * Under this assumptions we test if a double is inside a range + * where casting to long long is safe. Then using two castings we + * make sure the decimal part is zero. If all this is true we can use + * integer without precision loss. + * + * Note that numbers above 2^52 and below 2^63 use all the fraction bits as real part, + * and the exponent bits are positive, which means the "decimal" part must be 0. + * i.e. all double values in that range are representable as a long without precision loss, + * but not all long values in that range can be represented as a double. + * we only care about the first part here. */ + if (d < (double)(-LLONG_MAX / 2) || d > (double)(LLONG_MAX / 2)) return 0; + long long ll = d; + if (ll == d) { + *out = ll; + return 1; + } +#else + VALKEYMODULE_NOT_USED(d); + VALKEYMODULE_NOT_USED(out); +#endif + return 0; +} + +static ValkeyModuleString **luaArgsToServerArgv(ValkeyModuleCtx *ctx, lua_State *lua, int *argc) { int j; /* Require at least one argument */ *argc = lua_gettop(lua); @@ -814,12 +903,7 @@ static robj **luaArgsToServerArgv(lua_State *lua, int *argc, int *argv_len) { return NULL; } - /* Build the arguments vector (reuse a cached argv from last call) */ - if (lua_argv_size < *argc) { - lua_argv = zrealloc(lua_argv, sizeof(robj *) * *argc); - lua_argv_size = *argc; - } - *argv_len = lua_argv_size; + ValkeyModuleString **lua_argv = ValkeyModule_Alloc(sizeof(ValkeyModuleString *) * *argc); for (j = 0; j < *argc; j++) { char *obj_s; @@ -835,9 +919,9 @@ static robj **luaArgsToServerArgv(lua_State *lua, int *argc, int *argv_len) { * to convert it as an integer when that's possible, since the string could later be used * in a context that doesn't support scientific notation (e.g. 1e9 instead of 100000000). */ long long lvalue; - if (double2ll((double)num, &lvalue)) + if (double2ll((double)num, &lvalue)) { obj_len = ll2string(dbuf, sizeof(dbuf), lvalue); - else { + } else { obj_len = fpconv_dtoa((double)num, dbuf); dbuf[obj_len] = '\0'; } @@ -846,16 +930,8 @@ static robj **luaArgsToServerArgv(lua_State *lua, int *argc, int *argv_len) { obj_s = (char *)lua_tolstring(lua, j + 1, &obj_len); if (obj_s == NULL) break; /* Not a string. */ } - /* Try to use a cached object. */ - if (j < LUA_CMD_OBJCACHE_SIZE && lua_args_cached_objects[j] && lua_args_cached_objects_len[j] >= obj_len) { - sds s = lua_args_cached_objects[j]->ptr; - lua_argv[j] = lua_args_cached_objects[j]; - lua_args_cached_objects[j] = NULL; - memcpy(s, obj_s, obj_len + 1); - sdssetlen(s, obj_len); - } else { - lua_argv[j] = createStringObject(obj_s, obj_len); - } + + lua_argv[j] = ValkeyModule_CreateString(ctx, obj_s, obj_len); } /* Pop all arguments from the stack, we do not need them anymore @@ -866,51 +942,76 @@ static robj **luaArgsToServerArgv(lua_State *lua, int *argc, int *argv_len) { * is not a string or an integer (lua_isstring() return true for * integers as well). */ if (j != *argc) { - freeLuaServerArgv(lua_argv, j, lua_argv_size); - luaPushError(lua, "Command arguments must be strings or integers"); + freeLuaServerArgv(ctx, lua_argv, j); + luaPushError(lua, "ERR Command arguments must be strings or integers"); return NULL; } return lua_argv; } -void freeLuaServerArgv(robj **argv, int argc, int argv_len) { +void freeLuaServerArgv(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int argc) { int j; for (j = 0; j < argc; j++) { - robj *o = argv[j]; - - /* Try to cache the object in the lua_args_cached_objects array. - * The object must be small, SDS-encoded, and with refcount = 1 - * (we must be the only owner) for us to cache it. */ - if (j < LUA_CMD_OBJCACHE_SIZE && o->refcount == 1 && - (o->encoding == OBJ_ENCODING_RAW || o->encoding == OBJ_ENCODING_EMBSTR) && - sdslen(o->ptr) <= LUA_CMD_OBJCACHE_MAX_LEN) { - sds s = o->ptr; - if (lua_args_cached_objects[j]) decrRefCount(lua_args_cached_objects[j]); - lua_args_cached_objects[j] = o; - lua_args_cached_objects_len[j] = sdsalloc(s); - } else { - decrRefCount(o); + ValkeyModuleString *o = argv[j]; + ValkeyModule_FreeString(ctx, o); + } + ValkeyModule_Free(argv); +} + +static void luaProcessReplyError(ValkeyModuleCallReply *reply, lua_State *lua) { + const char *err = ValkeyModule_CallReplyStringPtr(reply, NULL); + int push_error = 1; + + /* The following error messages rewrites are required to keep the backward compatibility + * with the previous Lua engine that was implemented in Valkey core. */ + if (errno == ESPIPE) { + if (strncmp(err, "ERR command ", strlen("ERR command ")) == 0) { + luaPushError(lua, "ERR This Valkey command is not allowed from script"); + push_error = 0; + } + } else if (errno == EINVAL) { + if (strncmp(err, "ERR wrong number of arguments for ", strlen("ERR wrong number of arguments for ")) == 0) { + luaPushError(lua, "ERR Wrong number of args calling command from script"); + push_error = 0; + } + } else if (errno == ENOENT) { + if (strncmp(err, "ERR unknown command '", strlen("ERR unknown command '")) == 0) { + luaPushError(lua, "ERR Unknown command called from script"); + push_error = 0; + } + } else if (errno == EACCES) { + if (strncmp(err, "NOPERM ", strlen("NOPERM ")) == 0) { + const char *err_prefix = "ERR ACL failure in script: "; + size_t err_len = strlen(err_prefix) + strlen(err + strlen("NOPERM ")) + 1; + char *err_msg = ValkeyModule_Alloc(err_len * sizeof(char)); + bzero(err_msg, err_len); + strcpy(err_msg, err_prefix); + strcat(err_msg, err + strlen("NOPERM ")); + luaPushError(lua, err_msg); + ValkeyModule_Free(err_msg); + push_error = 0; } } - if (argv != lua_argv || argv_len != lua_argv_size) { - /* The command changed argv, scrap the cache and start over. */ - zfree(argv); - lua_argv = NULL; - lua_argv_size = 0; + + if (push_error) { + luaPushError(lua, err); } + /* push a field indicate to ignore updating the stats on this error + * because it was already updated when executing the command. */ + lua_pushstring(lua, "ignore_error_stats_update"); + lua_pushboolean(lua, 1); + lua_settable(lua, -3); } static int luaServerGenericCommand(lua_State *lua, int raise_error) { - int j; - scriptRunCtx *rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); - serverAssert(rctx); /* Only supported inside script invocation */ - sds err = NULL; - client *c = rctx->c; - sds reply; - - c->argv = luaArgsToServerArgv(lua, &c->argc, &c->argv_len); - if (c->argv == NULL) { + luaFuncCallCtx *rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); + ValkeyModule_Assert(rctx); /* Only supported inside script invocation */ + ValkeyModuleCallReply *reply; + + int argc = 0; + ValkeyModuleString **argv = luaArgsToServerArgv(rctx->module_ctx, lua, &argc); + if (argv == NULL) { return raise_error ? luaError(lua) : 1; } @@ -923,7 +1024,7 @@ static int luaServerGenericCommand(lua_State *lua, int raise_error) { if (inuse) { char *recursion_warning = "luaRedisGenericCommand() recursive call detected. " "Are you doing funny stuff with Lua debug hooks?"; - serverLog(LL_WARNING, "%s", recursion_warning); + ValkeyModule_Log(rctx->module_ctx, "warning", "%s", recursion_warning); luaPushError(lua, recursion_warning); return 1; } @@ -931,69 +1032,81 @@ static int luaServerGenericCommand(lua_State *lua, int raise_error) { /* Log the command if debugging is active. */ if (ldbIsEnabled()) { - sds cmdlog = sdsnew(""); - for (j = 0; j < c->argc; j++) { - if (j == 10) { - cmdlog = sdscatprintf(cmdlog, " ... (%d more)", c->argc - j - 1); + const char *cmd_prefix = ""; + char *cmdlog = ValkeyModule_Calloc(strlen(cmd_prefix) + 1, sizeof(char)); + strcpy(cmdlog, cmd_prefix); + for (int i = 0; i < argc; i++) { + if (i == 10) { + char *new_cmdlog = lm_asprintf("%s ... (%d more)", cmdlog, argc - i - 1); + ValkeyModule_Free(cmdlog); + cmdlog = new_cmdlog; break; } else { - cmdlog = sdscatlen(cmdlog, " ", 1); - cmdlog = sdscatsds(cmdlog, c->argv[j]->ptr); + const char *argv_cstr = ValkeyModule_StringPtrLen(argv[i], NULL); + char *new_cmdlog = lm_asprintf("%s %s", cmdlog, argv_cstr); + ValkeyModule_Free(cmdlog); + cmdlog = new_cmdlog; } } - ldbLog(cmdlog); + ldbLogCString(cmdlog); + ValkeyModule_Free(cmdlog); } - scriptCall(rctx, &err); - if (err) { - luaPushError(lua, err); - sdsfree(err); - /* push a field indicate to ignore updating the stats on this error - * because it was already updated when executing the command. */ - lua_pushstring(lua, "ignore_error_stats_update"); - lua_pushboolean(lua, 1); - lua_settable(lua, -3); - goto cleanup; + char fmt[13] = "v!EMSX"; + int fmt_idx = 6; /* Index of the last char in fmt[] */ + + ValkeyModuleString *username = ValkeyModule_GetCurrentUserName(rctx->module_ctx); + if (username != NULL) { + fmt[fmt_idx++] = 'C'; + ValkeyModule_FreeString(rctx->module_ctx, username); } - /* Convert the result of the command into a suitable Lua type. - * The first thing we need is to create a single string from the client - * output buffers. */ - if (listLength(c->reply) == 0 && (size_t)c->bufpos < c->buf_usable_size) { - /* This is a fast path for the common case of a reply inside the - * client static buffer. Don't create an SDS string but just use - * the client buffer directly. */ - c->buf[c->bufpos] = '\0'; - reply = c->buf; - c->bufpos = 0; - } else { - reply = sdsnewlen(c->buf, c->bufpos); - c->bufpos = 0; - while (listLength(c->reply)) { - clientReplyBlock *o = listNodeValue(listFirst(c->reply)); + if (!(rctx->replication_flags & PROPAGATE_AOF)) { + fmt[fmt_idx++] = 'A'; + } + if (!(rctx->replication_flags & PROPAGATE_REPL)) { + fmt[fmt_idx++] = 'R'; + } + if (!rctx->replication_flags) { + /* PROPAGATE_NONE case */ + fmt[fmt_idx++] = 'A'; + fmt[fmt_idx++] = 'R'; + } + if (rctx->resp == 3) { + fmt[fmt_idx++] = '3'; + } + fmt[fmt_idx] = '\0'; - reply = sdscatlen(reply, o->buf, o->used); - listDelNode(c->reply, listFirst(c->reply)); - } + const char *cmdname = ValkeyModule_StringPtrLen(argv[0], NULL); + + errno = 0; + reply = ValkeyModule_Call(rctx->module_ctx, cmdname, fmt, argv + 1, argc - 1); + freeLuaServerArgv(rctx->module_ctx, argv, argc); + int reply_type = ValkeyModule_CallReplyType(reply); + if (errno != 0) { + ValkeyModule_Assert(reply_type == VALKEYMODULE_REPLY_ERROR); + + const char *err = ValkeyModule_CallReplyStringPtr(reply, NULL); + ValkeyModule_Log(rctx->module_ctx, "debug", "command returned an error: %s errno=%d", err, errno); + + luaProcessReplyError(reply, lua); + goto cleanup; + } else if (raise_error && reply_type != VALKEYMODULE_REPLY_ERROR) { + raise_error = 0; } - if (raise_error && reply[0] != '-') raise_error = 0; - redisProtocolToLuaType(lua, reply); - /* If the debugger is active, log the reply from the server. */ - if (ldbIsEnabled()) - ldbLogRespReply(reply); + callReplyToLuaType(lua, reply, rctx->resp); - if (reply != c->buf) sdsfree(reply); - c->reply_bytes = 0; + /* If the debugger is active, log the reply from the server. */ + if (ldbIsEnabled()) { + ValkeyModule_ScriptingEngineDebuggerLogRespReply(reply); + } cleanup: /* Clean up. Command code may have changed argv/argc so we use the * argv/argc of the client instead of the local variables. */ - freeLuaServerArgv(c->argv, c->argc, c->argv_len); - c->argc = c->argv_len = 0; - c->user = NULL; - c->argv = NULL; - resetClient(c); + ValkeyModule_FreeCallReply(reply); + inuse--; if (raise_error) { @@ -1043,6 +1156,29 @@ static int luaRedisPCallCommand(lua_State *lua) { return luaServerGenericCommand(lua, 0); } +/* Perform the SHA1 of the input string. We use this both for hashing script + * bodies in order to obtain the Lua function name, and in the implementation + * of server.sha1(). + * + * 'digest' should point to a 41 bytes buffer: 40 for SHA1 converted into an + * hexadecimal number, plus 1 byte for null term. */ +void sha1hex(char *digest, char *script, size_t len) { + SHA1_CTX ctx; + unsigned char hash[20]; + char *cset = "0123456789abcdef"; + int j; + + SHA1Init(&ctx); + SHA1Update(&ctx, (unsigned char *)script, len); + SHA1Final(hash, &ctx); + + for (j = 0; j < 20; j++) { + digest[j * 2] = cset[((hash[j] & 0xF0) >> 4)]; + digest[j * 2 + 1] = cset[(hash[j] & 0xF)]; + } + digest[40] = '\0'; +} + /* This adds server.sha1hex(string) to Lua scripts using the same hashing * function used for sha1ing lua scripts. */ static int luaRedisSha1hexCommand(lua_State *lua) { @@ -1091,13 +1227,14 @@ static int luaRedisErrorReplyCommand(lua_State *lua) { /* add '-' if not exists */ const char *err = lua_tostring(lua, -1); - sds err_buff = NULL; + char *err_buff = NULL; if (err[0] != '-') { - err_buff = sdscatfmt(sdsempty(), "-%s", err); + err_buff = lm_asprintf("-%s", err); } else { - err_buff = sdsnew(err); + err_buff = lm_strcpy(err); } luaPushErrorBuff(lua, err_buff); + ValkeyModule_Free(err_buff); return 1; } @@ -1113,8 +1250,8 @@ static int luaRedisStatusReplyCommand(lua_State *lua) { static int luaRedisSetReplCommand(lua_State *lua) { int flags, argc = lua_gettop(lua); - scriptRunCtx *rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); - serverAssert(rctx); /* Only supported inside script invocation */ + luaFuncCallCtx *rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); + ValkeyModule_Assert(rctx); /* Only supported inside script invocation */ if (argc != 1) { luaPushError(lua, "server.set_repl() requires one argument."); @@ -1127,7 +1264,8 @@ static int luaRedisSetReplCommand(lua_State *lua) { return luaError(lua); } - scriptSetRepl(rctx, flags); + rctx->replication_flags = flags; + return 0; } @@ -1135,31 +1273,35 @@ static int luaRedisSetReplCommand(lua_State *lua) { * * Checks ACL permissions for given command for the current user. */ static int luaRedisAclCheckCmdPermissionsCommand(lua_State *lua) { - scriptRunCtx *rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); - serverAssert(rctx); /* Only supported inside script invocation */ + luaFuncCallCtx *rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); + ValkeyModule_Assert(rctx); /* Only supported inside script invocation */ + int raise_error = 0; - int argc, argv_len; - robj **argv = luaArgsToServerArgv(lua, &argc, &argv_len); + int argc = 0; + ValkeyModuleString **argv = luaArgsToServerArgv(rctx->module_ctx, lua, &argc); /* Require at least one argument */ if (argv == NULL) return luaError(lua); - /* Find command */ - struct serverCommand *cmd; - if ((cmd = lookupCommand(argv, argc)) == NULL) { - luaPushError(lua, "Invalid command passed to server.acl_check_cmd()"); - raise_error = 1; - } else { - int keyidxptr; - if (ACLCheckAllUserCommandPerm(rctx->original_client->user, cmd, argv, argc, &keyidxptr) != ACL_OK) { - lua_pushboolean(lua, 0); + ValkeyModuleString *username = ValkeyModule_GetCurrentUserName(rctx->module_ctx); + ValkeyModuleUser *user = ValkeyModule_GetModuleUserFromUserName(username); + ValkeyModule_FreeString(rctx->module_ctx, username); + + if (ValkeyModule_ACLCheckCommandPermissions(user, argv, argc) != VALKEYMODULE_OK) { + if (errno == ENOENT) { + luaPushError(lua, "ERR Invalid command passed to server.acl_check_cmd()"); + raise_error = 1; } else { - lua_pushboolean(lua, 1); + ValkeyModule_Assert(errno == EACCES); + lua_pushboolean(lua, 0); } + } else { + lua_pushboolean(lua, 1); } - freeLuaServerArgv(argv, argc, argv_len); + ValkeyModule_FreeModuleUser(user); + freeLuaServerArgv(rctx->module_ctx, argv, argc); if (raise_error) return luaError(lua); else @@ -1169,9 +1311,11 @@ static int luaRedisAclCheckCmdPermissionsCommand(lua_State *lua) { /* server.log() */ static int luaLogCommand(lua_State *lua) { + luaFuncCallCtx *rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); + ValkeyModule_Assert(rctx); /* Only supported inside script invocation */ + int j, argc = lua_gettop(lua); int level; - sds log; if (argc < 2) { luaPushError(lua, "server.log() requires two arguments or more."); @@ -1185,29 +1329,44 @@ static int luaLogCommand(lua_State *lua) { luaPushError(lua, "Invalid log level."); return luaError(lua); } - if (level < server.verbosity) return 0; /* Glue together all the arguments */ - log = sdsempty(); + char *log = NULL; for (j = 1; j < argc; j++) { size_t len; char *s; s = (char *)lua_tolstring(lua, (-argc) + j, &len); if (s) { - if (j != 1) log = sdscatlen(log, " ", 1); - log = sdscatlen(log, s, len); + if (j != 1) { + char *next_log = lm_asprintf("%s %s", log, s); + ValkeyModule_Free(log); + log = next_log; + } else { + log = lm_asprintf("%s", s); + } } } - serverLogRaw(level, log); - sdsfree(log); + + const char *level_str = NULL; + switch (level) { + case LL_DEBUG: level_str = "debug"; break; + case LL_VERBOSE: level_str = "verbose"; break; + case LL_NOTICE: level_str = "notice"; break; + case LL_WARNING: level_str = "warning"; break; + default: ValkeyModule_Assert(0); + } + + ValkeyModule_Log(rctx->module_ctx, level_str, "%s", log); + ValkeyModule_Free(log); return 0; } /* server.setresp() */ static int luaSetResp(lua_State *lua) { - scriptRunCtx *rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); - serverAssert(rctx); /* Only supported inside script invocation */ + luaFuncCallCtx *rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); + ValkeyModule_Assert(rctx); /* Only supported inside script invocation */ + int argc = lua_gettop(lua); if (argc != 1) { @@ -1220,7 +1379,9 @@ static int luaSetResp(lua_State *lua) { luaPushError(lua, "RESP version must be 2 or 3."); return luaError(lua); } - scriptSetResp(rctx, resp); + + rctx->resp = resp; + return 0; } @@ -1256,23 +1417,11 @@ static void luaLoadLibraries(lua_State *lua) { #endif } -/* Return sds of the string value located on stack at the given index. - * Return NULL if the value is not a string. */ -robj *luaGetStringObject(lua_State *lua, int index) { - if (!lua_isstring(lua, index)) { - return NULL; - } - - size_t len; - const char *str = lua_tolstring(lua, index, &len); - robj *str_obj = createStringObject(str, len); - return str_obj; -} - static int luaProtectedTableError(lua_State *lua) { + luaFuncCallCtx *rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); int argc = lua_gettop(lua); if (argc != 2) { - serverLog(LL_WARNING, "malicious code trying to call luaProtectedTableError with wrong arguments"); + ValkeyModule_Log(rctx->module_ctx, "warning", "malicious code trying to call luaProtectedTableError with wrong arguments"); luaL_error(lua, "Wrong number of arguments to luaProtectedTableError"); } if (!lua_isstring(lua, -1) && !lua_isnumber(lua, -1)) { @@ -1298,9 +1447,10 @@ void luaSetErrorMetatable(lua_State *lua) { } static int luaNewIndexAllowList(lua_State *lua) { + luaFuncCallCtx *rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); int argc = lua_gettop(lua); if (argc != 3) { - serverLog(LL_WARNING, "malicious code trying to call luaNewIndexAllowList with wrong arguments"); + ValkeyModule_Log(rctx->module_ctx, "warning", "malicious code trying to call luaNewIndexAllowList with wrong arguments"); luaL_error(lua, "Wrong number of arguments to luaNewIndexAllowList"); } if (!lua_istable(lua, -3)) { @@ -1333,7 +1483,7 @@ static int luaNewIndexAllowList(lua_State *lua) { for (; *c; ++c) { if (strcmp(*c, variable_name) == 0) { deprecated = 1; - allowed = server.lua_enable_insecure_api ? 1 : 0; + allowed = rctx->lua_enable_insecure_api ? 1 : 0; break; } } @@ -1348,10 +1498,10 @@ static int luaNewIndexAllowList(lua_State *lua) { } } if (!*c && !deprecated) { - serverLog(LL_WARNING, - "A key '%s' was added to Lua globals which is neither on the globals allow list nor listed on the " - "deny list.", - variable_name); + ValkeyModule_Log(rctx->module_ctx, "warning", + "A key '%s' was added to Lua globals which is neither on the globals allow list nor listed on the " + "deny list.", + variable_name); } } else { lua_rawset(lua, -3); @@ -1365,7 +1515,7 @@ static int luaNewIndexAllowList(lua_State *lua) { * The metatable is set on the table which located on the top * of the stack. */ -void luaSetAllowListProtection(lua_State *lua) { +static void luaSetAllowListProtection(lua_State *lua) { lua_newtable(lua); /* push metatable */ lua_pushcfunction(lua, luaNewIndexAllowList); /* push get error handler */ lua_setfield(lua, -2, "__newindex"); @@ -1432,27 +1582,27 @@ void luaSetTableProtectionForBasicTypes(lua_State *lua) { } } -void luaRegisterVersion(lua_State *lua) { +void luaRegisterVersion(luaEngineCtx *ctx, lua_State *lua) { /* For legacy compatibility reasons include Redis versions. */ lua_pushstring(lua, "REDIS_VERSION_NUM"); - lua_pushnumber(lua, REDIS_VERSION_NUM); + lua_pushnumber(lua, ctx->redis_version_num); lua_settable(lua, -3); lua_pushstring(lua, "REDIS_VERSION"); - lua_pushstring(lua, REDIS_VERSION); + lua_pushstring(lua, ctx->redis_version); lua_settable(lua, -3); /* Now push the Valkey version information. */ lua_pushstring(lua, "VALKEY_VERSION_NUM"); - lua_pushnumber(lua, VALKEY_VERSION_NUM); + lua_pushnumber(lua, ctx->valkey_version_num); lua_settable(lua, -3); lua_pushstring(lua, "VALKEY_VERSION"); - lua_pushstring(lua, VALKEY_VERSION); + lua_pushstring(lua, ctx->valkey_version); lua_settable(lua, -3); lua_pushstring(lua, "SERVER_NAME"); - lua_pushstring(lua, SERVER_NAME); + lua_pushstring(lua, ctx->server_name); lua_settable(lua, -3); } @@ -1484,7 +1634,7 @@ void luaRegisterLogFunction(lua_State *lua) { * This function only handles fields common between Functions and LUA scripting. * scriptingInit() and functionsInit() may add additional fields specific to each. */ -void luaRegisterServerAPI(lua_State *lua) { +void luaRegisterServerAPI(luaEngineCtx *ctx, lua_State *lua) { /* In addition to registering server.call/pcall API, we will throw a custom message when a script accesses * undefined global variable. LUA stores global variables in the global table, accessible to us on stack at virtual * index = LUA_GLOBALSINDEX. We will set __index handler in global table's metatable to a custom C function to @@ -1497,9 +1647,16 @@ void luaRegisterServerAPI(lua_State *lua) { luaSetAllowListProtection(lua); lua_pop(lua, 1); + luaFuncCallCtx call_ctx = { + .lua_enable_insecure_api = ctx->lua_enable_insecure_api, + }; + luaSaveOnRegistry(lua, REGISTRY_RUN_CTX_NAME, &call_ctx); + /* Add default C functions provided in deps/lua codebase to handle basic data types such as table, string etc. */ luaLoadLibraries(lua); + luaSaveOnRegistry(lua, REGISTRY_RUN_CTX_NAME, NULL); + /* Before Redis OSS 7, Lua used to return error messages as strings from pcall function. With Valkey (or Redis OSS 7), Lua now returns * error messages as tables. To keep backwards compatibility, we wrap the Lua pcall function with our own * implementation of C function that converts table to string. */ @@ -1520,7 +1677,7 @@ void luaRegisterServerAPI(lua_State *lua) { luaRegisterLogFunction(lua); /* Add SERVER_VERSION_NUM, SERVER_VERSION and SERVER_NAME fields with appropriate values. */ - luaRegisterVersion(lua); + luaRegisterVersion(ctx, lua); /* Add server.setresp function to allow LUA scripts to change the RESP version for server.call and server.pcall * invocations. */ @@ -1594,12 +1751,14 @@ void luaRegisterServerAPI(lua_State *lua) { /* Set an array of String Objects as a Lua array (table) stored into a * global variable. */ -static void luaCreateArray(lua_State *lua, robj **elev, int elec) { +static void luaCreateArray(lua_State *lua, ValkeyModuleString **elev, int elec) { int j; lua_createtable(lua, elec, 0); for (j = 0; j < elec; j++) { - lua_pushlstring(lua, (char *)elev[j]->ptr, sdslen(elev[j]->ptr)); + size_t len = 0; + const char *str = ValkeyModule_StringPtrLen(elev[j], &len); + lua_pushlstring(lua, str, len); lua_rawseti(lua, -2, j + 1); } } @@ -1649,17 +1808,20 @@ static int server_math_randomseed(lua_State *L) { /* This is the Lua script "count" hook that we use to detect scripts timeout. */ static void luaMaskCountHook(lua_State *lua, lua_Debug *ar) { - UNUSED(ar); - scriptRunCtx *rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); - serverAssert(rctx); /* Only supported inside script invocation */ - if (scriptInterrupt(rctx) == SCRIPT_KILL) { + VALKEYMODULE_NOT_USED(ar); + + luaFuncCallCtx *rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); + ValkeyModule_Assert(rctx); /* Only supported inside script invocation */ + + ValkeyModuleScriptingEngineExecutionState state = ValkeyModule_GetFunctionExecutionState(rctx->run_ctx); + if (state == VMSE_STATE_KILLED) { char *err = NULL; - if (rctx->flags & SCRIPT_EVAL_MODE) { - err = "Script killed by user with SCRIPT KILL."; + if (rctx->type == VMSE_EVAL) { + err = "ERR Script killed by user with SCRIPT KILL."; } else { - err = "Script killed by user with FUNCTION KILL."; + err = "ERR Script killed by user with FUNCTION KILL."; } - serverLog(LL_NOTICE, "%s", err); + ValkeyModule_Log(NULL, "notice", "%s", err); /* * Set the hook to invoke all the time so the user @@ -1674,14 +1836,14 @@ static void luaMaskCountHook(lua_State *lua, lua_Debug *ar) { } void luaErrorInformationDiscard(errorInfo *err_info) { - if (err_info->msg) sdsfree(err_info->msg); - if (err_info->source) sdsfree(err_info->source); - if (err_info->line) sdsfree(err_info->line); + if (err_info->msg) ValkeyModule_Free(err_info->msg); + if (err_info->source) ValkeyModule_Free(err_info->source); + if (err_info->line) ValkeyModule_Free(err_info->line); } void luaExtractErrorInformation(lua_State *lua, errorInfo *err_info) { if (lua_isstring(lua, -1)) { - err_info->msg = sdscatfmt(sdsempty(), "ERR %s", lua_tostring(lua, -1)); + err_info->msg = lm_asprintf("ERR %s", lua_tostring(lua, -1)); err_info->line = NULL; err_info->source = NULL; err_info->ignore_err_stats_update = 0; @@ -1690,19 +1852,19 @@ void luaExtractErrorInformation(lua_State *lua, errorInfo *err_info) { lua_getfield(lua, -1, "err"); if (lua_isstring(lua, -1)) { - err_info->msg = sdsnew(lua_tostring(lua, -1)); + err_info->msg = lm_strcpy(lua_tostring(lua, -1)); } lua_pop(lua, 1); lua_getfield(lua, -1, "source"); if (lua_isstring(lua, -1)) { - err_info->source = sdsnew(lua_tostring(lua, -1)); + err_info->source = lm_strcpy(lua_tostring(lua, -1)); } lua_pop(lua, 1); lua_getfield(lua, -1, "line"); if (lua_isstring(lua, -1)) { - err_info->line = sdsnew(lua_tostring(lua, -1)); + err_info->line = lm_strcpy(lua_tostring(lua, -1)); } lua_pop(lua, 1); @@ -1714,15 +1876,15 @@ void luaExtractErrorInformation(lua_State *lua, errorInfo *err_info) { if (err_info->msg == NULL) { /* Ensure we never return a NULL msg. */ - err_info->msg = sdsnew("ERR unknown error"); + err_info->msg = lm_strcpy("ERR unknown error"); } } /* This is the core of our Lua debugger, called each time Lua is about * to start executing a new line. */ void luaLdbLineHook(lua_State *lua, lua_Debug *ar) { - scriptRunCtx *rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); - serverAssert(rctx); /* Only supported inside script invocation */ + ValkeyModuleScriptingEngineServerRuntimeCtx *rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); + ValkeyModule_Assert(rctx); /* Only supported inside script invocation */ lua_getstack(lua, 0, ar); lua_getinfo(lua, "Sl", ar); ldbSetCurrentLine(ar->currentline); @@ -1735,14 +1897,14 @@ void luaLdbLineHook(lua_State *lua, lua_Debug *ar) { /* Check if a timeout occurred. */ if (ar->event == LUA_HOOKCOUNT && !ldbIsStepEnabled() && bp == 0) { - mstime_t elapsed = elapsedMs(rctx->start_time); - mstime_t timelimit = server.busy_reply_threshold ? server.busy_reply_threshold : 5000; - if (elapsed >= timelimit) { - timeout = 1; - ldbSetStepMode(1); - } else { - return; /* No timeout, ignore the COUNT event. */ - } + // mstime_t elapsed = elapsedMs(rctx->start_time); + // mstime_t timelimit = server.busy_reply_threshold ? server.busy_reply_threshold : 5000; + // if (elapsed >= timelimit) { + // timeout = 1; + // ldbSetStepMode(1); + // } else { + return; /* No timeout, ignore the COUNT event. */ + // } } if (ldbIsStepEnabled() || bp) { @@ -1753,7 +1915,8 @@ void luaLdbLineHook(lua_State *lua, lua_Debug *ar) { reason = "timeout reached, infinite loop?"; ldbSetStepMode(0); ldbSetBreakpointOnNextLine(0); - ldbLog(sdscatprintf(sdsempty(), "* Stopped at %d, stop reason = %s", ldbGetCurrentLine(), reason)); + ValkeyModuleString *msg = ValkeyModule_CreateStringPrintf(NULL, "* Stopped at %d, stop reason = %s", ldbGetCurrentLine(), reason); + ldbLog(msg); ldbLogSourceLine(ldbGetCurrentLine()); ldbSendLogs(); if (ldbRepl(lua) == C_ERR && timeout) { @@ -1763,27 +1926,39 @@ void luaLdbLineHook(lua_State *lua, lua_Debug *ar) { luaPushError(lua, "timeout during Lua debugging with client closing connection"); luaError(lua); } - rctx->start_time = getMonotonicUs(); + // rctx->start_time = getMonotonicUs(); } } -void luaCallFunction(scriptRunCtx *run_ctx, +void luaCallFunction(ValkeyModuleCtx *ctx, + ValkeyModuleScriptingEngineServerRuntimeCtx *run_ctx, + ValkeyModuleScriptingEngineSubsystemType type, lua_State *lua, - robj **keys, + ValkeyModuleString **keys, size_t nkeys, - robj **args, + ValkeyModuleString **args, size_t nargs, - int debug_enabled) { - client *c = run_ctx->original_client; + int debug_enabled, + int lua_enable_insecure_api) { int delhook = 0; /* We must set it before we set the Lua hook, theoretically the * Lua hook might be called wheneven we run any Lua instruction * such as 'luaSetGlobalArray' and we want the run_ctx to be available * each time the Lua hook is invoked. */ - luaSaveOnRegistry(lua, REGISTRY_RUN_CTX_NAME, run_ctx); - if (server.busy_reply_threshold > 0 && !debug_enabled) { + luaFuncCallCtx call_ctx = { + .module_ctx = ctx, + .run_ctx = run_ctx, + .type = type, + .replication_flags = PROPAGATE_AOF | PROPAGATE_REPL, + .resp = 2, + .lua_enable_insecure_api = lua_enable_insecure_api, + }; + + luaSaveOnRegistry(lua, REGISTRY_RUN_CTX_NAME, &call_ctx); + + if (!debug_enabled) { lua_sethook(lua, luaMaskCountHook, LUA_MASKCOUNT, 100000); delhook = 1; } else if (debug_enabled) { @@ -1795,14 +1970,14 @@ void luaCallFunction(scriptRunCtx *run_ctx, * EVAL received. */ luaCreateArray(lua, keys, nkeys); /* On eval, keys and arguments are globals. */ - if (run_ctx->flags & SCRIPT_EVAL_MODE) { + if (type == VMSE_EVAL) { /* open global protection to set KEYS */ lua_enablereadonlytable(lua, LUA_GLOBALSINDEX, 0); lua_setglobal(lua, "KEYS"); lua_enablereadonlytable(lua, LUA_GLOBALSINDEX, 1); } luaCreateArray(lua, args, nargs); - if (run_ctx->flags & SCRIPT_EVAL_MODE) { + if (type == VMSE_EVAL) { /* open global protection to set ARGV */ lua_enablereadonlytable(lua, LUA_GLOBALSINDEX, 0); lua_setglobal(lua, "ARGV"); @@ -1816,7 +1991,7 @@ void luaCallFunction(scriptRunCtx *run_ctx, * On function mode, we pass 2 arguments (the keys and args tables), * and the error handler is located on position -4 (stack: error_handler, callback, keys, args) */ int err; - if (run_ctx->flags & SCRIPT_EVAL_MODE) { + if (type == VMSE_EVAL) { err = lua_pcall(lua, 0, 1, -2); } else { err = lua_pcall(lua, 2, 1, -4); @@ -1848,24 +2023,33 @@ void luaCallFunction(scriptRunCtx *run_ctx, if (lua_isstring(lua, -1)) { msg = lua_tostring(lua, -1); } - addReplyErrorFormat(c, "Error running script %s, %.100s\n", run_ctx->funcname, msg); + ValkeyModule_ReplyWithErrorFormat(ctx, "ERR Error running script, %.100s\n", msg); } else { errorInfo err_info = {0}; - sds final_msg = sdsempty(); luaExtractErrorInformation(lua, &err_info); - final_msg = sdscatfmt(final_msg, "-%s", err_info.msg); if (err_info.line && err_info.source) { - final_msg = - sdscatfmt(final_msg, " script: %s, on %s:%s.", run_ctx->funcname, err_info.source, err_info.line); + ValkeyModule_ReplyWithCustomErrorFormat( + ctx, + !err_info.ignore_err_stats_update, + "%s script: on %s:%s.", + err_info.msg, + err_info.source, + err_info.line); + } else { + ValkeyModule_ReplyWithCustomErrorFormat( + ctx, + !err_info.ignore_err_stats_update, + "%s", + err_info.msg); } - addReplyErrorSdsEx(c, final_msg, err_info.ignore_err_stats_update ? ERR_REPLY_FLAG_NO_STATS_UPDATE : 0); luaErrorInformationDiscard(&err_info); } lua_pop(lua, 1); /* Consume the Lua error */ } else { /* On success convert the Lua return value into RESP, and * send it to * the client. */ - luaReplyToServerReply(c, run_ctx->c, lua); /* Convert and consume the reply. */ + + luaReplyToServerReply(ctx, call_ctx.resp, lua); /* Convert and consume the reply. */ } /* Perform some cleanup that we need to do both on error and success. */ diff --git a/src/lua/script_lua.h b/src/modules/lua/script_lua.h similarity index 81% rename from src/lua/script_lua.h rename to src/modules/lua/script_lua.h index 3ecbdf44c0..262e94c47f 100644 --- a/src/lua/script_lua.h +++ b/src/modules/lua/script_lua.h @@ -30,6 +30,8 @@ #ifndef __SCRIPT_LUA_H_ #define __SCRIPT_LUA_H_ +#include "engine_structs.h" + /* * script_lua.c unit provides shared functionality between * eval.c and function_lua.c. Functionality provided: @@ -48,47 +50,48 @@ * Uses script.c for interaction back with Redis. */ -#include "../server.h" -#include "../script.h" -#include -#include -#include +#define C_OK 0 +#define C_ERR -1 + +typedef struct lua_State lua_State; #define REGISTRY_RUN_CTX_NAME "__RUN_CTX__" +#define REGISTRY_MODULE_CTX_NAME "__MODULE_CTX__" #define REDIS_API_NAME "redis" #define SERVER_API_NAME "server" typedef struct errorInfo { - sds msg; - sds source; - sds line; + char *msg; + char *source; + char *line; int ignore_err_stats_update; } errorInfo; -void luaRegisterServerAPI(lua_State *lua); -robj *luaGetStringObject(lua_State *lua, int index); -void luaRegisterGlobalProtectionFunction(lua_State *lua); +void luaRegisterServerAPI(luaEngineCtx *ctx, lua_State *lua); void luaSetErrorMetatable(lua_State *lua); -void luaSetAllowListProtection(lua_State *lua); void luaSetTableProtectionRecursively(lua_State *lua); void luaSetTableProtectionForBasicTypes(lua_State *lua); void luaRegisterLogFunction(lua_State *lua); -void luaRegisterVersion(lua_State *lua); -void luaPushErrorBuff(lua_State *lua, sds err_buff); +void luaRegisterVersion(luaEngineCtx *ctx, lua_State *lua); void luaPushError(lua_State *lua, const char *error); int luaError(lua_State *lua); void luaSaveOnRegistry(lua_State *lua, const char *name, void *ptr); void *luaGetFromRegistry(lua_State *lua, const char *name); -void luaCallFunction(scriptRunCtx *r_ctx, +void luaCallFunction(ValkeyModuleCtx *ctx, + ValkeyModuleScriptingEngineServerRuntimeCtx *r_ctx, + ValkeyModuleScriptingEngineSubsystemType type, lua_State *lua, - robj **keys, + ValkeyModuleString **keys, size_t nkeys, - robj **args, + ValkeyModuleString **args, size_t nargs, - int debug_enabled); + int debug_enabled, + int lua_enable_insecure_api); void luaExtractErrorInformation(lua_State *lua, errorInfo *err_info); void luaErrorInformationDiscard(errorInfo *err_info); unsigned long luaMemory(lua_State *lua); +char *lm_strcpy(const char *str); + #endif /* __SCRIPT_LUA_H_ */ diff --git a/src/replication.c b/src/replication.c index ea6aa729a3..b6629f362a 100644 --- a/src/replication.c +++ b/src/replication.c @@ -4854,9 +4854,6 @@ void replicationRequestAckFromReplicas(void) { * returns the actual client woff */ long long getClientWriteOffset(client *c) { if (scriptIsRunning()) { - /* If a script is currently running, the client passed in is a fake - * client, and its woff is always 0. */ - serverAssert(scriptGetClient() == c); c = scriptGetCaller(); } return c->woff; diff --git a/src/script.c b/src/script.c index 23bf514458..4ea1ec6328 100644 --- a/src/script.c +++ b/src/script.c @@ -29,9 +29,7 @@ #include "server.h" #include "script.h" -#include "cluster.h" #include "cluster_slot_stats.h" -#include "module.h" scriptFlag scripts_flags_def[] = { {.flag = VMSE_SCRIPT_FLAG_NO_WRITES, .str = "no-writes"}, @@ -66,11 +64,6 @@ int scriptIsTimedout(void) { return scriptIsRunning() && (curr_run_ctx->flags & SCRIPT_TIMEDOUT); } -client *scriptGetClient(void) { - serverAssert(scriptIsRunning()); - return curr_run_ctx->c; -} - client *scriptGetCaller(void) { serverAssert(scriptIsRunning()); return curr_run_ctx->original_client; @@ -215,24 +208,10 @@ int scriptPrepareForRun(scriptRunCtx *run_ctx, run_ctx->engine = engine; - run_ctx->c = scriptingEngineGetClient(engine); run_ctx->original_client = caller; run_ctx->funcname = funcname; run_ctx->slot = caller->slot; - client *script_client = run_ctx->c; - client *curr_client = run_ctx->original_client; - - /* Select the right DB in the context of the Lua client */ - selectDb(script_client, curr_client->db->id); - script_client->resp = 2; /* Default is RESP2, scripts can change it. */ - - /* If we are in MULTI context, flag Lua client as CLIENT_MULTI. */ - if (curr_client->flag.multi) { - script_client->flag.multi = 1; - initClientMultiState(script_client); - } - run_ctx->start_time = getMonotonicUs(); run_ctx->flags = 0; @@ -264,9 +243,6 @@ int scriptPrepareForRun(scriptRunCtx *run_ctx, void scriptResetRun(scriptRunCtx *run_ctx) { serverAssert(curr_run_ctx); - /* After the script done, remove the MULTI state. */ - run_ctx->c->flag.multi = 0; - if (scriptIsTimedout()) { exitScriptTimedoutMode(run_ctx); /* Restore the client that was protected when the script timeout @@ -329,272 +305,6 @@ void scriptKill(client *c, int is_eval) { addReply(c, shared.ok); } -static int scriptVerifyCommandArity(struct serverCommand *cmd, int argc, sds *err) { - if (!cmd || ((cmd->arity > 0 && cmd->arity != argc) || (argc < -cmd->arity))) { - if (cmd) - *err = sdsnew("Wrong number of args calling command from script"); - else - *err = sdsnew("Unknown command called from script"); - return C_ERR; - } - return C_OK; -} - -static int scriptVerifyACL(client *c, sds *err) { - /* Check the ACLs. */ - int acl_errpos; - int acl_retval = ACLCheckAllPerm(c, &acl_errpos); - if (acl_retval != ACL_OK) { - addACLLogEntry(c, acl_retval, ACL_LOG_CTX_LUA, acl_errpos, NULL, NULL); - sds msg = getAclErrorMessage(acl_retval, c->user, c->cmd, c->argv[acl_errpos]->ptr, 0); - *err = sdscatsds(sdsnew("ACL failure in script: "), msg); - sdsfree(msg); - return C_ERR; - } - return C_OK; -} - -static int scriptVerifyWriteCommandAllow(scriptRunCtx *run_ctx, char **err) { - /* A write command, on an RO command or an RO script is rejected ASAP. - * Note: For scripts, we consider may-replicate commands as write commands. - * This also makes it possible to allow read-only scripts to be run during - * CLIENT PAUSE WRITE. */ - if (run_ctx->flags & SCRIPT_READ_ONLY && (run_ctx->c->cmd->flags & (CMD_WRITE | CMD_MAY_REPLICATE))) { - *err = sdsnew("Write commands are not allowed from read-only scripts."); - return C_ERR; - } - - /* The other checks below are on the server state and are only relevant for - * write commands, return if this is not a write command. */ - if (!(run_ctx->c->cmd->flags & CMD_WRITE)) return C_OK; - - /* If the script already made a modification to the dataset, we can't - * fail it on unpredictable error state. */ - if ((run_ctx->flags & SCRIPT_WRITE_DIRTY)) return C_OK; - - /* Write commands are forbidden against read-only replicas, or if a - * command marked as non-deterministic was already called in the context - * of this script. */ - int deny_write_type = writeCommandsDeniedByDiskError(); - - if (server.primary_host && server.repl_replica_ro && !mustObeyClient(run_ctx->original_client)) { - *err = sdsdup(shared.roreplicaerr->ptr); - return C_ERR; - } - - if (deny_write_type != DISK_ERROR_TYPE_NONE) { - *err = writeCommandsGetDiskErrorMessage(deny_write_type); - return C_ERR; - } - - /* Don't accept write commands if there are not enough good replicas and - * user configured the min-replicas-to-write option. Note this only reachable - * for Eval scripts that didn't declare flags, see the other check in - * scriptPrepareForRun */ - if (!checkGoodReplicasStatus()) { - *err = sdsdup(shared.noreplicaserr->ptr); - return C_ERR; - } - - return C_OK; -} - -static int scriptVerifyOOM(scriptRunCtx *run_ctx, char **err) { - if (run_ctx->flags & SCRIPT_ALLOW_OOM) { - /* Allow running any command even if OOM reached */ - return C_OK; - } - - /* If we reached the memory limit configured via maxmemory, commands that - * could enlarge the memory usage are not allowed, but only if this is the - * first write in the context of this script, otherwise we can't stop - * in the middle. */ - - if (server.maxmemory && /* Maxmemory is actually enabled. */ - !mustObeyClient(run_ctx->original_client) && /* Don't care about mem for replicas or AOF. */ - !(run_ctx->flags & SCRIPT_WRITE_DIRTY) && /* Script had no side effects so far. */ - server.pre_command_oom_state && /* Detected OOM when script start. */ - (run_ctx->c->cmd->flags & CMD_DENYOOM)) { - *err = sdsdup(shared.oomerr->ptr); - return C_ERR; - } - - return C_OK; -} - -static int scriptVerifyClusterState(scriptRunCtx *run_ctx, client *c, client *original_c, sds *err) { - if (!server.cluster_enabled || mustObeyClient(original_c)) { - return C_OK; - } - /* If this is a Cluster node, we need to make sure the script is not - * trying to access non-local keys, with the exception of commands - * received from our primary or when loading the AOF back in memory. */ - int error_code; - /* Duplicate relevant flags in the script client. */ - c->flag.readonly = original_c->flag.readonly; - c->flag.asking = original_c->flag.asking; - int hashslot = c->slot = clusterSlotByCommand(c->cmd, c->argv, c->argc, &c->read_flags); - if (getNodeByQuery(c, &error_code) != getMyClusterNode()) { - if (error_code == CLUSTER_REDIR_DOWN_RO_STATE) { - *err = sdsnew("Script attempted to execute a write command while the " - "cluster is down and readonly"); - } else if (error_code == CLUSTER_REDIR_DOWN_STATE) { - *err = sdsnew("Script attempted to execute a command while the " - "cluster is down"); - } else if (error_code == CLUSTER_REDIR_CROSS_SLOT) { - *err = sdscatfmt(sdsempty(), - "Command '%S' in script attempted to access keys that don't hash to the same slot", - c->cmd->fullname); - } else if (error_code == CLUSTER_REDIR_UNSTABLE) { - /* The request spawns multiple keys in the same slot, - * but the slot is not "stable" currently as there is - * a migration or import in progress. */ - *err = sdscatfmt(sdsempty(), - "Unable to execute command '%S' in script " - "because undeclared keys were accessed during rehashing of the slot", - c->cmd->fullname); - } else if (error_code == CLUSTER_REDIR_DOWN_UNBOUND) { - *err = sdsnew("Script attempted to access a slot not served"); - } else { - /* error_code == CLUSTER_REDIR_MOVED || error_code == CLUSTER_REDIR_ASK */ - *err = sdsnew("Script attempted to access a non local key in a " - "cluster node"); - } - return C_ERR; - } - - /* If the script declared keys in advanced, the cross slot error would have - * already been thrown. This is only checking for cross slot keys being accessed - * that weren't pre-declared. */ - if (hashslot != -1 && !(run_ctx->flags & SCRIPT_ALLOW_CROSS_SLOT)) { - if (run_ctx->slot == -1) { - run_ctx->slot = hashslot; - } else if (run_ctx->slot != hashslot) { - *err = sdsnew("Script attempted to access keys that do not hash to " - "the same slot"); - return C_ERR; - } - } - - original_c->slot = hashslot; - - return C_OK; -} - -/* set RESP for a given run_ctx */ -int scriptSetResp(scriptRunCtx *run_ctx, int resp) { - if (resp != 2 && resp != 3) { - return C_ERR; - } - - run_ctx->c->resp = resp; - return C_OK; -} - -/* set Repl for a given run_ctx - * either: PROPAGATE_AOF | PROPAGATE_REPL*/ -int scriptSetRepl(scriptRunCtx *run_ctx, int repl) { - if ((repl & ~(PROPAGATE_AOF | PROPAGATE_REPL)) != 0) { - return C_ERR; - } - run_ctx->repl_flags = repl; - return C_OK; -} - -static int scriptVerifyAllowStale(client *c, sds *err) { - if (!server.primary_host) { - /* Not a replica, stale is irrelevant */ - return C_OK; - } - - if (server.repl_state == REPL_STATE_CONNECTED) { - /* Connected to replica, stale is irrelevant */ - return C_OK; - } - - if (server.repl_serve_stale_data == 1) { - /* Disconnected from replica but allow to serve data */ - return C_OK; - } - - if (c->cmd->flags & CMD_STALE) { - /* Command is allow while stale */ - return C_OK; - } - - /* On stale replica, can not run the command */ - *err = sdsnew("Can not execute the command on a stale replica"); - return C_ERR; -} - -/* Call a server command. - * The reply is written to the run_ctx client and it is - * up to the engine to take and parse. - * The err out variable is set only if error occurs and describe the error. - * If err is set on reply is written to the run_ctx client. */ -void scriptCall(scriptRunCtx *run_ctx, sds *err) { - client *c = run_ctx->c; - - /* Setup our fake client for command execution */ - c->user = run_ctx->original_client->user; - - /* Process module hooks */ - moduleCallCommandFilters(c); - - struct serverCommand *cmd = lookupCommand(c->argv, c->argc); - c->cmd = c->lastcmd = c->realcmd = cmd; - if (scriptVerifyCommandArity(cmd, c->argc, err) != C_OK) { - goto error; - } - - /* There are commands that are not allowed inside scripts. */ - if (!server.script_disable_deny_script && (cmd->flags & CMD_NOSCRIPT)) { - *err = sdscatprintf(sdsempty(), "This %s command is not allowed from script", server.extended_redis_compat ? "Redis" : "Valkey"); - goto error; - } - - if (scriptVerifyAllowStale(c, err) != C_OK) { - goto error; - } - - if (scriptVerifyACL(c, err) != C_OK) { - goto error; - } - - if (scriptVerifyWriteCommandAllow(run_ctx, err) != C_OK) { - goto error; - } - - if (scriptVerifyOOM(run_ctx, err) != C_OK) { - goto error; - } - - if (cmd->flags & CMD_WRITE) { - /* signify that we already change the data in this execution */ - run_ctx->flags |= SCRIPT_WRITE_DIRTY; - } - - if (scriptVerifyClusterState(run_ctx, c, run_ctx->original_client, err) != C_OK) { - goto error; - } - - int call_flags = CMD_CALL_NONE; - if (run_ctx->repl_flags & PROPAGATE_AOF) { - call_flags |= CMD_CALL_PROPAGATE_AOF; - } - if (run_ctx->repl_flags & PROPAGATE_REPL) { - call_flags |= CMD_CALL_PROPAGATE_REPL; - } - call(c, call_flags); - serverAssert(c->flag.blocked == 0); - scriptClusterSlotStatsInvalidateSlotIfApplicable(); - return; - -error: - afterErrorReply(c, *err, sdslen(*err), 0); - incrCommandStatsOnError(cmd, ERROR_COMMAND_REJECTED); -} - long long scriptRunDuration(void) { serverAssert(scriptIsRunning()); return elapsedMs(curr_run_ctx->start_time); diff --git a/src/script.h b/src/script.h index 4b361dd1d7..85887a0b86 100644 --- a/src/script.h +++ b/src/script.h @@ -30,6 +30,8 @@ #ifndef __SCRIPT_H_ #define __SCRIPT_H_ +#include "valkeymodule.h" + /* * Script.c unit provides an API for functions and eval * to interact with the server. Interaction includes mostly @@ -74,7 +76,6 @@ typedef struct scriptRunCtx scriptRunCtx; struct scriptRunCtx { scriptingEngine *engine; const char *funcname; - client *c; client *original_client; int flags; int repl_flags; @@ -106,16 +107,12 @@ int scriptPrepareForRun(scriptRunCtx *r_ctx, uint64_t script_flags, int ro); void scriptResetRun(scriptRunCtx *r_ctx); -int scriptSetResp(scriptRunCtx *r_ctx, int resp); -int scriptSetRepl(scriptRunCtx *r_ctx, int repl); -void scriptCall(scriptRunCtx *r_ctx, sds *err); int scriptInterrupt(scriptRunCtx *r_ctx); void scriptKill(client *c, int is_eval); int scriptIsRunning(void); const char *scriptCurrFunction(void); int scriptIsEval(void); int scriptIsTimedout(void); -client *scriptGetClient(void); client *scriptGetCaller(void); long long scriptRunDuration(void); diff --git a/src/scripting_engine.c b/src/scripting_engine.c index aa16e96991..f3e4265cdb 100644 --- a/src/scripting_engine.c +++ b/src/scripting_engine.c @@ -5,6 +5,7 @@ */ #include "scripting_engine.h" +#include "bio.h" #include "dict.h" #include "functions.h" #include "module.h" @@ -49,7 +50,6 @@ typedef struct scriptingEngine { sds name; /* Name of the engine */ ValkeyModule *module; /* the module that implements the scripting engine */ scriptingEngineImpl impl; /* engine context and callbacks to interact with the engine */ - client *client; /* Client that is used to run commands */ ValkeyModuleCtx *module_ctx_cache[MODULE_CTX_CACHE_SIZE]; /* Cache of module context objects */ } scriptingEngine; @@ -131,6 +131,7 @@ int scriptingEngineManagerRegister(const char *engine_name, ValkeyModule *engine_module, engineCtx *engine_ctx, engineMethods *engine_methods) { + serverAssert(engine_name != NULL); sds engine_name_sds = sdsnew(engine_name); if (dictFetchValue(engineMgr.engines, engine_name_sds)) { @@ -139,11 +140,6 @@ int scriptingEngineManagerRegister(const char *engine_name, return C_ERR; } - client *c = createClient(NULL); - c->flag.deny_blocking = 1; - c->flag.script = 1; - c->flag.fake = 1; - scriptingEngine *e = zmalloc(sizeof(*e)); *e = (scriptingEngine){ .name = engine_name_sds, @@ -151,7 +147,6 @@ int scriptingEngineManagerRegister(const char *engine_name, .impl = { .ctx = engine_ctx, }, - .client = c, .module_ctx_cache = {0}, }; scriptingEngineInitializeEngineMethods(e, engine_methods); @@ -191,7 +186,12 @@ int scriptingEngineManagerUnregister(const char *engine_name) { mem_info.engine_memory_overhead; sdsfree(e->name); - freeClient(e->client); + + /* We need to ensure that any pending async flush of eval scripts or + * functions have completed before freeing the module context cache, which + * may be used by the async jobs. */ + bioDrainWorker(BIO_LAZY_FREE); + for (size_t i = 0; i < MODULE_CTX_CACHE_SIZE; i++) { serverAssert(e->module_ctx_cache[i] != NULL); zfree(e->module_ctx_cache[i]); @@ -219,10 +219,6 @@ sds scriptingEngineGetName(scriptingEngine *engine) { return engine->name; } -client *scriptingEngineGetClient(scriptingEngine *engine) { - return engine->client; -} - ValkeyModule *scriptingEngineGetModule(scriptingEngine *engine) { return engine->module; } @@ -250,12 +246,18 @@ void scriptingEngineManagerForEachEngine(engineIterCallback callback, static ValkeyModuleCtx *engineSetupModuleCtx(int module_ctx_cache_index, scriptingEngine *e, + int add_script_execution_flag, + int add_thread_safe_flag, client *c) { serverAssert(e != NULL); if (e->module == NULL) return NULL; ValkeyModuleCtx *ctx = e->module_ctx_cache[module_ctx_cache_index]; - moduleScriptingEngineInitContext(ctx, e->module, c); + moduleScriptingEngineInitContext(ctx, + e->module, + add_script_execution_flag, + add_thread_safe_flag, + c); return ctx; } @@ -276,7 +278,7 @@ compiledFunction **scriptingEngineCallCompileCode(scriptingEngine *engine, robj **err) { serverAssert(type == VMSE_EVAL || type == VMSE_FUNCTION); compiledFunction **functions = NULL; - ValkeyModuleCtx *module_ctx = engineSetupModuleCtx(COMMON_MODULE_CTX_INDEX, engine, NULL); + ValkeyModuleCtx *module_ctx = engineSetupModuleCtx(COMMON_MODULE_CTX_INDEX, engine, false, false, NULL); if (engine->impl.methods.version == SCRIPTING_ENGINE_ABI_VERSION_1) { functions = engine->impl.methods.compile_code_v1( @@ -309,7 +311,7 @@ void scriptingEngineCallFreeFunction(scriptingEngine *engine, subsystemType type, compiledFunction *compiled_func) { serverAssert(type == VMSE_EVAL || type == VMSE_FUNCTION); - ValkeyModuleCtx *module_ctx = engineSetupModuleCtx(FREE_FUNCTION_MODULE_CTX_INDEX, engine, NULL); + ValkeyModuleCtx *module_ctx = engineSetupModuleCtx(FREE_FUNCTION_MODULE_CTX_INDEX, engine, false, true, NULL); engine->impl.methods.free_function( module_ctx, engine->impl.ctx, @@ -329,7 +331,7 @@ void scriptingEngineCallFunction(scriptingEngine *engine, size_t nargs) { serverAssert(type == VMSE_EVAL || type == VMSE_FUNCTION); - ValkeyModuleCtx *module_ctx = engineSetupModuleCtx(COMMON_MODULE_CTX_INDEX, engine, caller); + ValkeyModuleCtx *module_ctx = engineSetupModuleCtx(COMMON_MODULE_CTX_INDEX, engine, true, false, caller); engine->impl.methods.call_function( module_ctx, @@ -347,7 +349,7 @@ void scriptingEngineCallFunction(scriptingEngine *engine, size_t scriptingEngineCallGetFunctionMemoryOverhead(scriptingEngine *engine, compiledFunction *compiled_function) { - ValkeyModuleCtx *module_ctx = engineSetupModuleCtx(COMMON_MODULE_CTX_INDEX, engine, NULL); + ValkeyModuleCtx *module_ctx = engineSetupModuleCtx(COMMON_MODULE_CTX_INDEX, engine, false, false, NULL); size_t mem = engine->impl.methods.get_function_memory_overhead( module_ctx, compiled_function); @@ -358,7 +360,7 @@ size_t scriptingEngineCallGetFunctionMemoryOverhead(scriptingEngine *engine, callableLazyEnvReset *scriptingEngineCallResetEnvFunc(scriptingEngine *engine, subsystemType type, int async) { - ValkeyModuleCtx *module_ctx = engineSetupModuleCtx(COMMON_MODULE_CTX_INDEX, engine, NULL); + ValkeyModuleCtx *module_ctx = engineSetupModuleCtx(COMMON_MODULE_CTX_INDEX, engine, false, false, NULL); callableLazyEnvReset *callback = NULL; if (engine->impl.methods.version < SCRIPTING_ENGINE_ABI_VERSION_3) { @@ -393,7 +395,7 @@ callableLazyEnvReset *scriptingEngineCallResetEnvFunc(scriptingEngine *engine, engineMemoryInfo scriptingEngineCallGetMemoryInfo(scriptingEngine *engine, subsystemType type) { - ValkeyModuleCtx *module_ctx = engineSetupModuleCtx(GET_MEMORY_MODULE_CTX_INDEX, engine, NULL); + ValkeyModuleCtx *module_ctx = engineSetupModuleCtx(GET_MEMORY_MODULE_CTX_INDEX, engine, false, false, NULL); engineMemoryInfo mem_info = engine->impl.methods.get_memory_info( module_ctx, engine->impl.ctx, @@ -420,7 +422,7 @@ debuggerEnableRet scriptingEngineCallDebuggerEnable(scriptingEngine *engine, return VMSE_DEBUG_NOT_SUPPORTED; } - ValkeyModuleCtx *module_ctx = engineSetupModuleCtx(COMMON_MODULE_CTX_INDEX, engine, NULL); + ValkeyModuleCtx *module_ctx = engineSetupModuleCtx(COMMON_MODULE_CTX_INDEX, engine, false, false, NULL); debuggerEnableRet ret = engine->impl.methods.debugger_enable( module_ctx, engine->impl.ctx, @@ -436,7 +438,7 @@ void scriptingEngineCallDebuggerDisable(scriptingEngine *engine, serverAssert(engine->impl.methods.version >= SCRIPTING_ENGINE_ABI_VERSION_4); serverAssert(engine->impl.methods.debugger_disable != NULL); - ValkeyModuleCtx *module_ctx = engineSetupModuleCtx(COMMON_MODULE_CTX_INDEX, engine, NULL); + ValkeyModuleCtx *module_ctx = engineSetupModuleCtx(COMMON_MODULE_CTX_INDEX, engine, false, false, NULL); engine->impl.methods.debugger_disable( module_ctx, engine->impl.ctx, @@ -450,7 +452,7 @@ void scriptingEngineCallDebuggerStart(scriptingEngine *engine, serverAssert(engine->impl.methods.version >= SCRIPTING_ENGINE_ABI_VERSION_4); serverAssert(engine->impl.methods.debugger_start != NULL); - ValkeyModuleCtx *module_ctx = engineSetupModuleCtx(COMMON_MODULE_CTX_INDEX, engine, NULL); + ValkeyModuleCtx *module_ctx = engineSetupModuleCtx(COMMON_MODULE_CTX_INDEX, engine, false, false, NULL); engine->impl.methods.debugger_start( module_ctx, engine->impl.ctx, @@ -464,7 +466,7 @@ void scriptingEngineCallDebuggerEnd(scriptingEngine *engine, serverAssert(engine->impl.methods.version >= SCRIPTING_ENGINE_ABI_VERSION_4); serverAssert(engine->impl.methods.debugger_end != NULL); - ValkeyModuleCtx *module_ctx = engineSetupModuleCtx(COMMON_MODULE_CTX_INDEX, engine, NULL); + ValkeyModuleCtx *module_ctx = engineSetupModuleCtx(COMMON_MODULE_CTX_INDEX, engine, false, false, NULL); engine->impl.methods.debugger_end( module_ctx, engine->impl.ctx, diff --git a/src/scripting_engine.h b/src/scripting_engine.h index 7fec7c2e73..44c176f41c 100644 --- a/src/scripting_engine.h +++ b/src/scripting_engine.h @@ -51,7 +51,6 @@ void scriptingEngineManagerForEachEngine(engineIterCallback callback, void *cont * Engine API functions. */ sds scriptingEngineGetName(scriptingEngine *engine); -client *scriptingEngineGetClient(scriptingEngine *engine); ValkeyModule *scriptingEngineGetModule(scriptingEngine *engine); uint64_t scriptingEngineGetAbiVersion(scriptingEngine *engine); diff --git a/src/server.c b/src/server.c index 44bd09ad4e..3934c440bc 100644 --- a/src/server.c +++ b/src/server.c @@ -50,7 +50,7 @@ #include "sds.h" #include "module.h" #include "scripting_engine.h" -#include "lua/engine_lua.h" + #include "eval.h" #include "trace/trace_commands.h" @@ -3018,11 +3018,12 @@ void initServer(void) { * commands with `CMD_NOSCRIPT` flag are not allowed to run in scripts. */ server.script_disable_deny_script = 0; - /* Initialize the LUA scripting engine. */ - if (luaEngineInitEngine() != C_OK) { - serverPanic("Lua engine initialization failed, check the server logs."); - exit(1); - } + commandlogInit(); + latencyMonitorInit(); + initSharedQueryBuf(); + + /* Initialize ACL default password if it exists */ + ACLUpdateDefaultUserPassword(server.requirepass); /* Initialize the functions engine based off of LUA initialization. */ if (functionsInit() == C_ERR) { @@ -3032,13 +3033,6 @@ void initServer(void) { /* Initialize the EVAL scripting component. */ evalInit(); - commandlogInit(); - latencyMonitorInit(); - initSharedQueryBuf(); - - /* Initialize ACL default password if it exists */ - ACLUpdateDefaultUserPassword(server.requirepass); - applyWatchdogPeriod(); if (server.maxmemory_clients != 0) initServerClientMemUsageBuckets(); @@ -4820,6 +4814,8 @@ int finishShutdown(void) { /* Close the listening sockets. Apparently this allows faster restarts. */ closeListeningSockets(1); + moduleUnloadAllModules(); + serverLog(LL_WARNING, "%s is now ready to exit, bye bye...", server.sentinel_mode ? "Sentinel" : "Valkey"); return C_OK; @@ -7419,6 +7415,18 @@ __attribute__((weak)) int main(int argc, char **argv) { if (server.cluster_enabled) { clusterInitLast(); } + + /* Initialize the LUA scripting engine. */ +#ifdef LUA_ENGINE_ENABLED +#define LUA_ENGINE_LIB_STR STRINGIFY(LUA_ENGINE_LIB) + if (scriptingEngineManagerFind("lua") == NULL) { + if (moduleLoad(LUA_ENGINE_LIB_STR, NULL, 0, 0) != C_OK) { + serverPanic("Lua engine initialization failed, check the server logs."); + } + } +#endif + + InitServerLast(); if (!server.sentinel_mode) { diff --git a/tests/modules/datatype2.c b/tests/modules/datatype2.c index 642b64fe29..58ca23bf63 100644 --- a/tests/modules/datatype2.c +++ b/tests/modules/datatype2.c @@ -737,6 +737,24 @@ int ValkeyModule_OnLoad(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int arg ValkeyModule_SubscribeToServerEvent(ctx, ValkeyModuleEvent_FlushDB, flushdbCallback); ValkeyModule_SubscribeToServerEvent(ctx, ValkeyModuleEvent_SwapDB, swapDbCallback); - + return VALKEYMODULE_OK; } + +int ValkeyModule_OnUnload(ValkeyModuleCtx *ctx) { + VALKEYMODULE_NOT_USED(ctx); + + for(int i = 0; i < MAX_DB; i++){ + ValkeyModuleString *key; + void *tdata; + ValkeyModuleDictIter *iter = ValkeyModule_DictIteratorStartC(mem_pool[i], "^", NULL, 0); + while((key = ValkeyModule_DictNext(ctx, iter, &tdata)) != NULL) { + MemBlockFree((struct MemBlock *)tdata); + ValkeyModule_FreeString(ctx, key); + } + ValkeyModule_DictIteratorStop(iter); + ValkeyModule_FreeDict(NULL, mem_pool[i]); + } + + return VALKEYMODULE_OK; +} \ No newline at end of file diff --git a/tests/modules/hooks.c b/tests/modules/hooks.c index a9a9d27b2d..fb3c9b86e6 100644 --- a/tests/modules/hooks.c +++ b/tests/modules/hooks.c @@ -116,7 +116,7 @@ void clearEvents(ValkeyModuleCtx *ctx) { ValkeyModuleString *key; EventElement *event; - ValkeyModuleDictIter *iter = ValkeyModule_DictIteratorStart(event_log, "^", NULL); + ValkeyModuleDictIter *iter = ValkeyModule_DictIteratorStartC(event_log, "^", NULL, 0); while((key = ValkeyModule_DictNext(ctx, iter, (void**)&event)) != NULL) { event->count = 0; event->last_val_int = 0; @@ -124,6 +124,8 @@ void clearEvents(ValkeyModuleCtx *ctx) event->last_val_string = NULL; ValkeyModule_DictDel(event_log, key, NULL); ValkeyModule_Free(event); + ValkeyModule_DictIteratorReseek(iter, ">=", key); + ValkeyModule_FreeString(ctx, key); } ValkeyModule_DictIteratorStop(iter); } diff --git a/tests/modules/usercall.c b/tests/modules/usercall.c index b61c919fc1..e2bbb105c7 100644 --- a/tests/modules/usercall.c +++ b/tests/modules/usercall.c @@ -227,3 +227,14 @@ int ValkeyModule_OnLoad(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int arg return VALKEYMODULE_OK; } + +int ValkeyModule_OnUnload(ValkeyModuleCtx *ctx) { + VALKEYMODULE_NOT_USED(ctx); + + if (user != NULL) { + ValkeyModule_FreeModuleUser(user); + user = NULL; + } + + return VALKEYMODULE_OK; +} diff --git a/tests/unit/moduleapi/hooks.tcl b/tests/unit/moduleapi/hooks.tcl index f0e3559af6..8376883361 100644 --- a/tests/unit/moduleapi/hooks.tcl +++ b/tests/unit/moduleapi/hooks.tcl @@ -373,7 +373,7 @@ tags "modules" { # look into the log file of the server that just exited test {Test shutdown hook} { - assert_equal [string match {*module-event-shutdown*} [exec tail -5 < $replica_stdout]] 1 + assert_equal [string match {*module-event-shutdown*} [exec tail -6 < $replica_stdout]] 1 } } diff --git a/tests/unit/moduleapi/scriptingengine.tcl b/tests/unit/moduleapi/scriptingengine.tcl index 47d5f96f8a..6a80381d6d 100644 --- a/tests/unit/moduleapi/scriptingengine.tcl +++ b/tests/unit/moduleapi/scriptingengine.tcl @@ -469,7 +469,7 @@ start_server {tags {"modules"}} { assert_match "*name=HELLO*" $info # Verify LUA is built-in and HELLO is from module - assert_match "*name=LUA,module=built-in*" $info + assert_match "*name=LUA,module=lua*" $info assert_match "*name=HELLO,module=helloengine*" $info }