From 3dc58a355233a8c19576ada207ec0d110151cf36 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 3 Dec 2024 09:46:47 -0500 Subject: [PATCH] Cache WTSE Evaluation (#7896) * Cache WTSE Evaluation Make sure init and generate are called the smallest possible time makes getFrame basically free the second time due to cache invalidates only when resoultion etc change changes the ui so dragging frame count or resolution just activtes apply button and doesn't recalc --- src/common/dsp/WavetableScriptEvaluator.cpp | 401 +++++++++++--------- src/common/dsp/WavetableScriptEvaluator.h | 14 +- src/surge-testrunner/UnitTestsLUA.cpp | 21 +- src/surge-xt/gui/overlays/LuaEditors.cpp | 15 +- 4 files changed, 243 insertions(+), 208 deletions(-) diff --git a/src/common/dsp/WavetableScriptEvaluator.cpp b/src/common/dsp/WavetableScriptEvaluator.cpp index ca6dd8d2492..8cf30baf249 100644 --- a/src/common/dsp/WavetableScriptEvaluator.cpp +++ b/src/common/dsp/WavetableScriptEvaluator.cpp @@ -24,6 +24,9 @@ #include "LuaSupport.h" #include "lua/LuaSources.h" +// #define LOG(...) std::cout << __FILE__ << ":" << __LINE__ << " " << __VA_ARGS__ << std::endl; +#define LOG(...) + namespace Surge { namespace WavetableScript @@ -39,23 +42,18 @@ struct LuaWTEvaluator::Details size_t resolution{2048}; size_t frameCount{10}; - bool needsParse{false}; + bool isValid{false}; + std::vector> frameCache; + std::string wtName{"Scripted Wavetable"}; + void prepareIfInvalid(); lua_State *L{nullptr}; - void prepare() + void invalidate() { - if (L == nullptr) - { - L = lua_open(); - luaL_openlibs(L); - - auto wg = Surge::LuaSupport::SGLD("WavetableScript::prelude", L); - - Surge::LuaSupport::loadSurgePrelude(L, Surge::LuaSources::wtse_prelude); - } + isValid = false; + frameCache.clear(); } - void makeEmptyState(bool pushToGlobal) { lua_createtable(L, 0, 10); @@ -67,9 +65,134 @@ struct LuaWTEvaluator::Details if (pushToGlobal) lua_setglobal(L, statetable); } + + LuaWTEvaluator::frame_t generateScriptAtFrame(size_t frame) + { + LOG("generateScriptAtFrame " << frame); + auto &eqn = script; + + if (!makeValid()) + return std::nullopt; + + LuaWTEvaluator::frame_t res{std::nullopt}; + auto values = std::vector(); + + auto wgp = Surge::LuaSupport::SGLD("WavetableScript::evaluateInner", L); + lua_getglobal(L, "generate"); + if (!lua_isfunction(L, -1)) + { + if (storage) + storage->reportError("Unable to locate generate function", + "Wavetable Script Evaluator"); + return std::nullopt; + } + Surge::LuaSupport::setSurgeFunctionEnvironment(L); + + lua_createtable(L, 0, 10); + + lua_getglobal(L, statetable); + + lua_pushnil(L); /* first key */ + assert(lua_istable(L, -2)); + bool useLegacyNames{false}; + + while (lua_next(L, -2) != 0) + { + // stack is now new > global > k > v but we want to see if k 'legacy_config' is + // true + if (lua_isstring(L, -2)) + { + if (strcmp(lua_tostring(L, -2), "legacy_config") == 0) + { + if (lua_isboolean(L, -1)) + { + useLegacyNames = lua_toboolean(L, -1); + } + } + } + // stack is now new > global > k > v + lua_pushvalue(L, -2); + // stack is now new > global > k > v > k + lua_insert(L, -2); + // stack is now new > global > k > k > v + lua_settable(L, -5); + // stack is now new > global > k and k/v is inserted into new so we can iterate + } + // pop the remaining key + lua_pop(L, 1); + + if (useLegacyNames) + { + // xs is an array of the x locations in phase space + lua_createtable(L, resolution, 0); + double dp = 1.0 / (resolution - 1); + for (auto i = 0; i < resolution; ++i) + { + lua_pushinteger(L, i + 1); // lua has a 1 based index convention + lua_pushnumber(L, i * dp); + lua_settable(L, -3); + } + lua_setfield(L, -2, "xs"); + + lua_pushinteger(L, frame + 1); + lua_setfield(L, -2, "n"); + + lua_pushinteger(L, frameCount); + lua_setfield(L, -2, "nTables"); + } + + lua_pushinteger(L, frame + 1); + lua_setfield(L, -2, "frame"); + + lua_pushinteger(L, frameCount); + lua_setfield(L, -2, "frame_count"); + + lua_pushinteger(L, resolution); + lua_setfield(L, -2, "sample_count"); + + // So stack is now the table and the function + auto pcr = lua_pcall(L, 1, 1, 0); + if (pcr == LUA_OK) + { + if (lua_istable(L, -1)) + { + bool gen{true}; + for (auto i = 0; i < resolution && gen; ++i) + { + lua_pushinteger(L, i + 1); + lua_gettable(L, -2); + if (lua_isnumber(L, -1)) + { + values.push_back(lua_tonumber(L, -1)); + } + else + { + values.push_back(0.f); + gen = false; + } + lua_pop(L, 1); + } + if (gen) + res = values; + } + } + else + { + // If pcr is not LUA_OK then lua pushes an error string onto the stack. Show this + // error + std::string luaerr = lua_tostring(L, -1); + if (storage) + storage->reportError(luaerr, "Wavetable Evaluator Runtime Error"); + else + std::cerr << luaerr; + } + lua_pop(L, 1); // Error string or pcall result + + return res; + } void callInitFn() { - prepare(); + LOG("callInitFn"); auto wg = Surge::LuaSupport::SGLD("WavetableScript::details::callInitFn", L); lua_getglobal(L, "init"); @@ -112,11 +235,25 @@ struct LuaWTEvaluator::Details } } } - bool parseIfNeeded() + + bool makeValid() { - prepare(); - if (needsParse) + if (L == nullptr) { + LOG("creating Lua State "); + + L = lua_open(); + luaL_openlibs(L); + + auto wg = Surge::LuaSupport::SGLD("WavetableScript::prelude", L); + + Surge::LuaSupport::loadSurgePrelude(L, Surge::LuaSources::wtse_prelude); + } + + if (!isValid) + { + LOG("Validating"); + { // Have a separate guard for this just to make sure I match auto lwg = Surge::LuaSupport::SGLD("WavetableScript::details::clearGlobals", L); @@ -126,9 +263,14 @@ struct LuaWTEvaluator::Details lua_setglobal(L, "init"); lua_pushnil(L); lua_setglobal(L, statetable); + wtName = "Scripted Wavetable"; + + frameCache.clear(); + for (int i = 0; i < frameCount; ++i) + frameCache.push_back(std::nullopt); } - auto wg = Surge::LuaSupport::SGLD("WavetableScript::details::parseIfNeeded", L); + auto wg = Surge::LuaSupport::SGLD("WavetableScript::details::makeValid", L); std::string emsg; auto res = Surge::LuaSupport::parseStringDefiningMultipleFunctions( L, script, {"init", "generate"}, emsg); @@ -141,14 +283,29 @@ struct LuaWTEvaluator::Details callInitFn(); - needsParse = false; + { + auto wgn = + Surge::LuaSupport::SGLD("WavetableScript::details::makeValid::wtName", L); + lua_getglobal(L, statetable); + if (lua_istable(L, -1)) + { + lua_getfield(L, -1, "name"); + if (lua_isstring(L, -1)) + { + wtName = lua_tostring(L, -1); + } + + lua_pop(L, -1); + } + + lua_pop(L, -1); + } + + isValid = true; return res; } - else - { - return true; - } + return true; } }; #else @@ -173,163 +330,60 @@ void LuaWTEvaluator::setScript(const std::string &e) if (e != details->script) { details->script = e; - details->needsParse = true; + details->invalidate(); } #endif } void LuaWTEvaluator::setResolution(size_t r) { #if HAS_LUA - details->resolution = r; + if (r != details->resolution) + { + details->resolution = r; + details->invalidate(); + } #endif } void LuaWTEvaluator::setFrameCount(size_t n) { #if HAS_LUA - details->frameCount = n; + if (n != details->frameCount) + { + details->invalidate(); + details->frameCount = n; + } #endif } -std::optional> LuaWTEvaluator::evaluateScriptAtFrame(size_t frame) +LuaWTEvaluator::frame_t LuaWTEvaluator::getFrame(size_t frame) { #if HAS_LUA - auto storage = details->storage; - auto &eqn = details->script; - auto resolution = details->resolution; - auto nFrames = details->frameCount; - - std::optional> res{std::nullopt}; - auto values = std::vector(); - - details->prepare(); - auto L = details->L; - - auto wg = Surge::LuaSupport::SGLD("WavetableScript::evaluate", L); - - if (details->parseIfNeeded()) + if (!details->makeValid()) + return std::nullopt; + if (frame > details->frameCount) + return std::nullopt; + assert(frame < details->frameCache.size()); + if (!details->frameCache[frame].has_value()) { - auto wgp = Surge::LuaSupport::SGLD("WavetableScript::evaluateInner", L); - lua_getglobal(details->L, "generate"); - if (!lua_isfunction(details->L, -1)) - { - if (storage) - storage->reportError("Unable to locate generate function", - "Wavetable Script Evaluator"); - return std::nullopt; - } - Surge::LuaSupport::setSurgeFunctionEnvironment(L); - - lua_createtable(L, 0, 10); - - lua_getglobal(L, statetable); - - lua_pushnil(L); /* first key */ - assert(lua_istable(L, -2)); - bool useLegacyNames{true}; - - while (lua_next(L, -2) != 0) - { - // stack is now new > global > k > v but we want to see if k 'legacy_config' is true - if (lua_isstring(L, -2)) - { - if (strcmp(lua_tostring(L, -2), "legacy_config") == 0) - { - if (lua_isboolean(L, -1)) - { - useLegacyNames = lua_toboolean(L, -1); - } - } - } - // stack is now new > global > k > v - lua_pushvalue(L, -2); - // stack is now new > global > k > v > k - lua_insert(L, -2); - // stack is now new > global > k > k > v - lua_settable(L, -5); - // stack is now new > global > k and k/v is inserted into new so we can iterate - } - // pop the remaining key - lua_pop(L, 1); - - if (useLegacyNames) - { - // xs is an array of the x locations in phase space - lua_createtable(L, resolution, 0); - double dp = 1.0 / (resolution - 1); - for (auto i = 0; i < resolution; ++i) - { - lua_pushinteger(L, i + 1); // lua has a 1 based index convention - lua_pushnumber(L, i * dp); - lua_settable(L, -3); - } - lua_setfield(L, -2, "xs"); - - lua_pushinteger(L, frame + 1); - lua_setfield(L, -2, "n"); - - lua_pushinteger(L, nFrames); - lua_setfield(L, -2, "nTables"); - } - - lua_pushinteger(L, frame + 1); - lua_setfield(L, -2, "frame"); - - lua_pushinteger(L, nFrames); - lua_setfield(L, -2, "frame_count"); - - lua_pushinteger(L, resolution); - lua_setfield(L, -2, "sample_count"); - - // So stack is now the table and the function - auto pcr = lua_pcall(L, 1, 1, 0); - if (pcr == LUA_OK) - { - if (lua_istable(L, -1)) - { - bool gen{true}; - for (auto i = 0; i < resolution && gen; ++i) - { - lua_pushinteger(L, i + 1); - lua_gettable(L, -2); - if (lua_isnumber(L, -1)) - { - values.push_back(lua_tonumber(L, -1)); - } - else - { - values.push_back(0.f); - gen = false; - } - lua_pop(L, 1); - } - if (gen) - res = values; - } - } - else - { - // If pcr is not LUA_OK then lua pushes an error string onto the stack. Show this error - std::string luaerr = lua_tostring(L, -1); - if (storage) - storage->reportError(luaerr, "Wavetable Evaluator Runtime Error"); - else - std::cerr << luaerr; - } - lua_pop(L, 1); // Error string or pcall result + details->frameCache[frame] = details->generateScriptAtFrame(frame); } - - return res; + if (details->frameCache[frame].has_value()) + { + return *(details->frameCache[frame]); + } + return std::nullopt; #else return std::nullopt; #endif } -bool LuaWTEvaluator::constructWavetable(wt_header &wh, float **wavdata) +bool LuaWTEvaluator::populateWavetable(wt_header &wh, float **wavdata) { #if HAS_LUA - auto storage = details->storage; - auto &eqn = details->script; + if (!details->makeValid()) + return false; + auto resolution = details->resolution; auto frames = details->frameCount; @@ -339,12 +393,9 @@ bool LuaWTEvaluator::constructWavetable(wt_header &wh, float **wavdata) wh.flags = 0; *wavdata = wd; - details->prepare(); - details->parseIfNeeded(); - details->callInitFn(); for (int i = 0; i < frames; ++i) { - auto v = evaluateScriptAtFrame(i); + auto v = getFrame(i); if (v.has_value()) { memcpy(&(wd[i * resolution]), &((*v)[0]), resolution * sizeof(float)); @@ -362,39 +413,11 @@ bool LuaWTEvaluator::constructWavetable(wt_header &wh, float **wavdata) std::string LuaWTEvaluator::getSuggestedWavetableName() { - std::string res = "Scripted Wavetable"; - #if HAS_LUA - details->prepare(); - details->parseIfNeeded(); - details->callInitFn(); - - auto L = details->L; - auto wgp = Surge::LuaSupport::SGLD("WavetableScript::evaluateInner", L); - lua_getglobal(L, statetable); - if (lua_istable(L, -1)) - { - lua_getfield(L, -1, "name"); - if (lua_isstring(L, -1)) - { - res = lua_tostring(L, -1); - } - - lua_pop(L, -1); - } - - lua_pop(L, -1); -#endif - - return res; -} - -void LuaWTEvaluator::prepare() -{ -#if HAS_LUA - details->prepare(); - details->parseIfNeeded(); - details->callInitFn(); + details->makeValid(); + return details->wtName; +#else + return ""; #endif } diff --git a/src/common/dsp/WavetableScriptEvaluator.h b/src/common/dsp/WavetableScriptEvaluator.h index 63100677615..2f5d9610b2b 100644 --- a/src/common/dsp/WavetableScriptEvaluator.h +++ b/src/common/dsp/WavetableScriptEvaluator.h @@ -43,20 +43,16 @@ struct LuaWTEvaluator void setResolution(size_t); void setFrameCount(size_t); - void prepare(); - - /* - * Unlike the LFO modulator this is called at render time of the wavetable - * not at the evaluation or synthesis time. As such I expect you call it from - * one thread at a time and just you know generally be careful. - */ - std::optional> evaluateScriptAtFrame(size_t frame); + using validFrame_t = std::vector; + using frame_t = std::optional; /* * Generate all the data required to call BuildWT. The wavdata here is data you * must free with delete[] */ - bool constructWavetable(wt_header &wh, float **wavdata); + bool populateWavetable(wt_header &wh, float **wavdata); + + frame_t getFrame(size_t frame); std::string getSuggestedWavetableName(); diff --git a/src/surge-testrunner/UnitTestsLUA.cpp b/src/surge-testrunner/UnitTestsLUA.cpp index 1373a6df7ee..3154dcc1ff9 100644 --- a/src/surge-testrunner/UnitTestsLUA.cpp +++ b/src/surge-testrunner/UnitTestsLUA.cpp @@ -665,10 +665,21 @@ TEST_CASE("Wavetable Script", "[formula]") { SurgeStorage storage; const std::string s = R"FN( -function generate(config) - res = config.xs - for i,x in ipairs(config.xs) do - res[i] = math.sin(2 * math.pi * x * config.n) + +function init(wt) + -- wt will have frame_count and sample_count defined + wt.name = "Fourier Saw" + wt.phase = math.linspace(0.0, 1.0, wt.sample_count) + return wt +end + +function generate(wt) + local res = {} + + for i,x in ipairs(wt.phase) do + local lv = 0 + lv = sin(2 * pi * wt.frame * x) + res[i] = lv end return res end @@ -681,7 +692,7 @@ end for (int fno = 0; fno < 4; ++fno) { - auto fr = la->evaluateScriptAtFrame(fno); + auto fr = la->getFrame(fno); REQUIRE(fr.has_value()); REQUIRE(fr->size() == 512); auto dp = 1.0 / (512 - 1); diff --git a/src/surge-xt/gui/overlays/LuaEditors.cpp b/src/surge-xt/gui/overlays/LuaEditors.cpp index 93bfa7942d2..9ef18d77b74 100644 --- a/src/surge-xt/gui/overlays/LuaEditors.cpp +++ b/src/surge-xt/gui/overlays/LuaEditors.cpp @@ -1877,6 +1877,7 @@ struct WavetableScriptControlArea : public juce::Component, break; case tag_frames_value: { + /* int currentFrame = currentFrameN->getIntValue(); int maxFrames = framesN->getIntValue(); if (currentFrame > maxFrames) @@ -1886,12 +1887,17 @@ struct WavetableScriptControlArea : public juce::Component, } overlay->osc->wavetable_formula_nframes = maxFrames; overlay->rerenderFromUIState(); + */ + overlay->setApplyEnabled(true); } break; case tag_res_value: { + /* overlay->osc->wavetable_formula_res_base = resolutionN->getIntValue(); overlay->rerenderFromUIState(); + */ + overlay->setApplyEnabled(true); } break; @@ -2018,11 +2024,10 @@ void WavetableScriptEditor::setupEvaluator() for (int i = 1; i < resi; ++i) respt *= 2; + evaluator->setStorage(storage); evaluator->setScript(mainDocument->getAllContent().toStdString()); evaluator->setResolution(respt); evaluator->setFrameCount(controlArea->framesN->getIntValue()); - - evaluator->prepare(); } void WavetableScriptEditor::applyCode() @@ -2142,7 +2147,7 @@ void WavetableScriptEditor::rerenderFromUIState() if (rm == 0) { - auto rs = evaluator->evaluateScriptAtFrame(cfr); + auto rs = evaluator->getFrame(cfr); if (rs.has_value()) { rendererComponent->points = *rs; @@ -2164,7 +2169,7 @@ void WavetableScriptEditor::rerenderFromUIState() } else { - auto rs = evaluator->evaluateScriptAtFrame(i); + auto rs = evaluator->getFrame(i); if (rs.has_value()) { rendererComponent->fsPoints.emplace_back(*rs); @@ -2208,7 +2213,7 @@ void WavetableScriptEditor::generateWavetable() wt_header wh; float *wd = nullptr; setupEvaluator(); - evaluator->constructWavetable(wh, &wd); + evaluator->populateWavetable(wh, &wd); storage->waveTableDataMutex.lock(); osc->wt.BuildWT(wd, wh, wh.flags & wtf_is_sample); osc->wavetable_display_name = evaluator->getSuggestedWavetableName();