diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..6466075 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,60 @@ +# Git +.git +.gitignore + +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual environments +venv/ +env/ +ENV/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo + +# OS +.DS_Store +Thumbs.db + +# Logs +*.log + +# Documentation +docs/_build/ + +# Test coverage +.coverage +htmlcov/ + +# Temporary files +*.tmp +*.temp + +models/ + +Dockerfile* + +*.md diff --git a/.gitignore b/.gitignore index 90ab3e6..f319416 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,13 @@ **/__pycache__/ /models /Pipfile +/modules/autocog/*.so +/build/ +ftt.svg +.venv/ +modules/autocog.egg-info/ +**/__build + +*.sta.json +*.ftt.json +**/DEBUG.log diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..a65ceb2 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,6 @@ +[submodule "vendors/llama"] + path = vendors/llama + url = https://github.com/ggerganov/llama.cpp.git +[submodule "vendors/reflex"] + path = vendors/reflex + url = https://github.com/Genivia/RE-flex.git diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..13d26cd --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,78 @@ +cmake_minimum_required(VERSION 3.18) +project(autocog) + +# Set C++ standard +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Build type" FORCE) +endif() +set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") + +if(CMAKE_BUILD_TYPE STREQUAL "Debug") + add_compile_definitions(VERBOSE=1) + add_compile_definitions(DEBUG=1) +else() + add_compile_definitions(VERBOSE=0) + add_compile_definitions(NDEBUG=1) +endif() + +set(COMMON_CXX_FLAGS "-Wall -Wextra -Werror") 
+if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") + set(CMAKE_CXX_FLAGS_DEBUG "-g -O0 -rdynamic -fno-omit-frame-pointer -Wl,--wrap=__cxa_throw ${COMMON_CXX_FLAGS}") + set(CMAKE_CXX_FLAGS_RELEASE "-O3 -DNDEBUG ${COMMON_CXX_FLAGS}") +endif() + +find_package(Python COMPONENTS Interpreter Development REQUIRED) +find_package(pybind11 QUIET) +if(NOT pybind11_FOUND) + execute_process( + COMMAND ${Python_EXECUTABLE} -c "import pybind11; print(pybind11.get_cmake_dir())" + OUTPUT_VARIABLE pybind11_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + ) + if(pybind11_DIR) + find_package(pybind11 REQUIRED PATHS ${pybind11_DIR}) + else() + message(FATAL_ERROR "pybind11 not found. Install with: pip install pybind11") + endif() +endif() + +find_program(REFLEX reflex REQUIRED) +find_library(REFLEX_LIB reflex REQUIRED) +find_path(REFLEX_INC + NAMES reflex/flexlexer.h + HINTS /opt/include /usr/local/include + REQUIRED +) + +find_library(LLAMA_LIB + NAMES llama + HINTS /opt/lib /opt/lib64 /usr/local/lib /usr/local/lib64 + REQUIRED +) + +find_path(LLAMA_INC + NAMES llama.h + HINTS /opt/include /usr/local/include + REQUIRED +) + +enable_testing() + +add_subdirectory(libs/autocog/utilities) + +add_subdirectory(libs/autocog/compiler/stl) +add_subdirectory(tools/stlc) +add_subdirectory(bindings/compiler-stl) + +add_subdirectory(libs/autocog/llama/xfta) +add_subdirectory(tools/xfta) +add_subdirectory(bindings/llama-xfta) + +add_subdirectory(tests/units) + +add_subdirectory(share/demos) + diff --git a/DEVEL.md b/DEVEL.md new file mode 100644 index 0000000..6baa72f --- /dev/null +++ b/DEVEL.md @@ -0,0 +1,143 @@ +# Develop Commands Cheat Sheet + +## Get models + +``` +mkdir -p models +cd models +wget -O SmolLM3-Q4_K_M.gguf https://huggingface.co/ggml-org/SmolLM3-3B-GGUF/resolve/main/SmolLM3-Q4_K_M.gguf?download=true +``` + +## Container + +### CPU (for dev) + +Build the image: +``` +docker build -t autocog:ubi -f Dockerfile.ubi . 
+``` + +Run one-off commands: +``` +docker run --rm -v $(pwd):/workspace -w /workspace autocog:ubi scripts/sanity-check.sh +``` + +**Recommended: Use a persistent container for development:** +```bash +# Start persistent container +docker run -d --name autocog --rm -v $(pwd):/workspace -w /workspace autocog:ubi sleep infinity + +# Execute commands in the container +docker exec autocog bash -c "cd /tmp && cmake /workspace && make install && ctest" + +# Interactive shell +docker exec -it autocog bash + +# Stop container when done +docker stop autocog +``` + +### CUDA on RHEL with Podman + +First setup CUDA for container use (3rd command is for FIPS enable machines): +```bash +curl -s -L https://nvidia.github.io/libnvidia-container/stable/rpm/nvidia-container-toolkit.repo | \ + sudo tee /etc/yum.repos.d/nvidia-container-toolkit.repo +sudo dnf --disablerepo=\* --enablerepo=nvidia-container-toolkit-experimental install -y nvidia-container-toolkit +sudo rpm -ivh --nodigest --nofiledigest /var/cache/dnf/nvidia-container-toolkit-experimental-*/packages/*.rpm +sudo nvidia-ctk cdi generate --output=/etc/cdi/nvidia.yaml +sudo chmod o+r /etc/cdi/nvidia.yaml +sudo chmod o+rx /etc/cdi +podman run --rm --device nvidia.com/gpu=all docker.io/nvidia/cuda:12.4.0-runtime-ubuntu22.04 nvidia-smi +``` + +Build and run persistent container: +```bash +podman build --device nvidia.com/gpu=all -f Dockerfile.ubi-cuda -t autocog:ubi-cuda . +podman run -d --name autocog --rm --device nvidia.com/gpu=all \ + -v $(pwd):/workspace -w /workspace autocog:ubi-cuda sleep infinity + +# Execute commands +podman exec autocog bash -c "cd /tmp && cmake /workspace && make install && ctest" +``` + +### Ubuntu with CUDA (Docker) + +```bash +docker build -t autocog:ubuntu -f Dockerfile.ubuntu . 
+docker run -d --name autocog --rm --gpus all \ + -v $(pwd):/workspace -w /workspace autocog:ubuntu sleep infinity +``` + +## Testing Components + +### xFTA (Finite Thoughts Automaton Executor) + +Testing the C++ utility: +```bash +python3 scripts/dump_sta_to_json.py tests/samples/mini.sta models/SmolLM3-Q4_K_M.gguf +xfta -v -m models/SmolLM3-Q4_K_M.gguf tests/samples/mini.sta.json +``` + +Testing the integration: +```bash +python3 scripts/execute_sta_with_llama_cpp.py tests/samples/mini.sta '{}' models/SmolLM3-Q4_K_M.gguf +``` + +### STLC (Structured Thoughts Language Compiler) + +Test compilation: +```bash +stlc -h +stlc tests/samples/defines.stl +stlc share/demos/story-writer/story-writer.stl +``` + +## Building and Testing + +### Build Types + +**Debug build** (with exception backtrace wrapper): +```bash +mkdir -p /tmp/autocog && cd /tmp/autocog +cmake /workspace -DCMAKE_BUILD_TYPE=Debug -DCMAKE_INSTALL_PREFIX=/opt +make install -j$(nproc) +``` + +**Release build** (optimized, default): +```bash +mkdir -p /tmp/autocog && cd /tmp/autocog +cmake /workspace -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/opt +make install -j$(nproc) +``` + +### Running Tests + +Run all tests: +```bash +cd /tmp/autocog +ctest --output-on-failure +``` + +Run specific test suites: +```bash +ctest -R stl_parser # All parser tests +ctest -R smoke # Smoke tests only +ctest -R stl_parser_identifiers -V # Single test with verbose output +``` + +Run in persistent container: +```bash +docker exec autocog bash -c "cd /tmp && cmake /workspace -DCMAKE_BUILD_TYPE=Release && make install -j\$(nproc) && ctest --output-on-failure" +``` + +### Python Package Testing + +Install and test the Python package: +```bash +pip install /workspace +python -c "import autocog" +python -c "import autocog.compiler.stl" +python -c "import autocog.llama.xfta" +``` + diff --git a/Dockerfile.ubi b/Dockerfile.ubi new file mode 100644 index 0000000..89df54c --- /dev/null +++ b/Dockerfile.ubi @@ -0,0 +1,61 @@ +FROM 
registry.access.redhat.com/ubi9/ubi:latest + +ENV PYTHONUNBUFFERED=1 + +RUN dnf --disablerepo=\* --enablerepo=ubi-\* install -y \ + gcc gcc-c++ \ + cmake \ + pkgconf-pkg-config \ + python3 \ + python3-pip \ + python3-devel \ + graphviz && \ + dnf clean all + +# Venv + +RUN python3 -m venv /opt + +ENV PATH="/opt/bin:$PATH" +ENV LD_LIBRARY_PATH="/opt/lib:/opt/lib64:/usr/local/lib:/usr/local/lib64:$LD_LIBRARY_PATH" +ENV PYTHONPATH="" + +RUN pip install --upgrade pip && \ + pip install graphviz + +# RE/Flex + +COPY vendors/reflex /tmp/reflex + +RUN cd /tmp/reflex && \ + ./build.sh && \ + mkdir -p /opt/include/reflex && \ + cp lib/*.a lib/*.so /opt/lib && \ + cp bin/reflex /opt/bin && \ + cp include/reflex/*.h /opt/include/reflex && \ + rm -rf /tmp/reflex + +# LLama.cpp + +COPY vendors/llama /tmp/llama_cpp + +RUN cd /tmp/llama_cpp && \ + cmake -B build \ + -DCMAKE_INSTALL_PREFIX=/opt \ + -DLLAMA_BUILD_COMMON=OFF \ + -DLLAMA_CUDA=OFF && \ + cmake --build build --parallel $(nproc) && \ + cmake --install build && \ + rm -rf /tmp/llama_cpp + +# Autocog + +COPY . 
/tmp/autocog + +RUN pip install /tmp/autocog && \ + rm -rf /tmp/autocog + +ENV PATH="/opt/lib64/python3.9/site-packages/bin/:$PATH" + +WORKDIR /workspace + diff --git a/Dockerfile.ubi-cuda b/Dockerfile.ubi-cuda new file mode 100644 index 0000000..19bab1f --- /dev/null +++ b/Dockerfile.ubi-cuda @@ -0,0 +1,60 @@ +FROM docker.io/nvidia/cuda:13.0.0-devel-ubi9 + +ENV PYTHONUNBUFFERED=1 + +RUN dnf --disablerepo=\* --enablerepo=ubi-\* install -y \ + gcc gcc-c++ \ + cmake \ + pkgconf-pkg-config \ + python3 \ + python3-pip \ + python3-devel \ + graphviz && \ + dnf clean all + +# Venv + +RUN python3 -m venv /opt + +ENV PATH="/opt/bin:$PATH" +ENV LD_LIBRARY_PATH="/opt/lib:/opt/lib64:/usr/local/lib:/usr/local/lib64:$LD_LIBRARY_PATH" +ENV PYTHONPATH="" + +RUN pip install --upgrade pip && \ + pip install graphviz + +# RE/Flex + +COPY vendors/reflex /tmp/reflex + +RUN cd /tmp/reflex && \ + ./build.sh && \ + mkdir -p /opt/include/reflex && \ + cp lib/*.a lib/*.so /opt/lib && \ + cp bin/reflex /opt/bin && \ + cp include/reflex/*.h /opt/include/reflex && \ + rm -rf /tmp/reflex + +# LLama.cpp + +COPY vendors/llama /tmp/llama_cpp + +RUN cd /tmp/llama_cpp && \ + cmake -B build \ + -DCMAKE_INSTALL_PREFIX=/opt \ + -DLLAMA_BUILD_COMMON=OFF \ + -DLLAMA_CUDA=ON && \ + cmake --build build --parallel $(nproc) && \ + cmake --install build && \ + rm -rf /tmp/llama_cpp + +# Autocog + +COPY . 
/tmp/autocog + +RUN pip install /tmp/autocog && \ + rm -rf /tmp/autocog + +ENV PATH="/opt/lib64/python3.9/site-packages/bin/:$PATH" + +WORKDIR /workspace diff --git a/Dockerfile.ubuntu b/Dockerfile.ubuntu new file mode 100644 index 0000000..fddfce1 --- /dev/null +++ b/Dockerfile.ubuntu @@ -0,0 +1,62 @@ +FROM ubuntu:24.04 + +ENV DEBIAN_FRONTEND=noninteractive +ENV PYTHONUNBUFFERED=1 + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + build-essential \ + cmake \ + pkg-config \ + python3 \ + python3-pip \ + python3-dev \ + python3-venv \ + graphviz && \ + rm -rf /var/lib/apt/lists/* + +# Venv + +RUN python3 -m venv /opt + +ENV PATH="/opt/bin:$PATH" +ENV LD_LIBRARY_PATH="/opt/lib:/opt/lib64:/usr/local/lib:/usr/local/lib64:$LD_LIBRARY_PATH" +ENV PYTHONPATH="" + +RUN pip install --upgrade pip && \ + pip install graphviz + +# RE/Flex + +COPY vendors/reflex /tmp/reflex + +RUN cd /tmp/reflex && \ + ./build.sh && \ + mkdir -p /opt/include/reflex && \ + cp lib/*.a lib/*.so /opt/lib && \ + cp bin/reflex /opt/bin && \ + cp include/reflex/*.h /opt/include/reflex && \ + rm -rf /tmp/reflex + +# LLama.cpp + +COPY vendors/llama /tmp/llama_cpp + +RUN cd /tmp/llama_cpp && \ + cmake -B build \ + -DCMAKE_INSTALL_PREFIX=/opt \ + -DLLAMA_BUILD_COMMON=OFF \ + -DLLAMA_CUDA=ON && \ + cmake --build build --parallel $(nproc) && \ + cmake --install build && \ + rm -rf /tmp/llama_cpp + +# Autocog + +COPY . /workspace/autocog + +RUN pip install /workspace/autocog && \ + rm -rf /workspace/autocog + +WORKDIR /workspace + diff --git a/README.md b/README.md index a0bff45..74af711 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,20 @@ We broke down the documentation into a few files: The libraries have [their own documentation](./share/library/README.md). 
+## Syntax Highlight + +For `gedit`, the path is based on newer version of : +``` +mkdir -p ~/.local/share/libgedit-gtksourceview-300/language-specs +cp syntax-highlight/gedit/stl.lang ~/.local/share/libgedit-gtksourceview-300/language-specs +``` + +For `vscode` (it does not seem to work but I don't use VSCode): +``` +mkdir -p ~/.vscode/extensions/stl-language +cp -r syntax-highlight/vscode/* ~/.vscode/extensions/stl-language +``` + ## Contributing Contributions are welcome! diff --git a/autocog/lm/__init__.py b/autocog/lm/__init__.py deleted file mode 100644 index 389d323..0000000 --- a/autocog/lm/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ - -from .random import RLM -from .llama import Llama -from .transformers import TfLM diff --git a/bindings/compiler-stl/CMakeLists.txt b/bindings/compiler-stl/CMakeLists.txt new file mode 100644 index 0000000..a633619 --- /dev/null +++ b/bindings/compiler-stl/CMakeLists.txt @@ -0,0 +1,30 @@ + +################################################################################################### +# Binding autocog::compiler::stl to autocog.compiler.stl + +pybind11_add_module(autocog_compiler_stl_pybind + module.cxx + convert.cxx +) +target_include_directories(autocog_compiler_stl_pybind PRIVATE + ${PROJECT_SOURCE_DIR}/libs + ${PROJECT_SOURCE_DIR}/vendors/headers +) +target_link_libraries(autocog_compiler_stl_pybind PRIVATE + autocog_compiler_stl_lib + autocog_utilities_lib +) +set_property(TARGET autocog_compiler_stl_pybind PROPERTY POSITION_INDEPENDENT_CODE ON) +set_target_properties(autocog_compiler_stl_pybind PROPERTIES + OUTPUT_NAME "stl_cxx" + PREFIX "${PYTHON_MODULE_PREFIX}" + SUFFIX "${PYTHON_MODULE_EXTENSION}" +) +install(TARGETS autocog_compiler_stl_pybind DESTINATION autocog/compiler/stl) + +add_test(NAME python_stlc_binding_smoke_test + COMMAND ${Python_EXECUTABLE} -c "import autocog.compiler.stl" + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) + +set_tests_properties(python_stlc_binding_smoke_test PROPERTIES + 
ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}:$ENV{PYTHONPATH}") diff --git a/bindings/compiler-stl/convert.cxx b/bindings/compiler-stl/convert.cxx new file mode 100644 index 0000000..db2dcef --- /dev/null +++ b/bindings/compiler-stl/convert.cxx @@ -0,0 +1,9 @@ + +#include "convert.hxx" + +namespace autocog::compiler::stl { + +// TODO IR convertion to Python (with Protobuf?) + +} + diff --git a/bindings/compiler-stl/convert.hxx b/bindings/compiler-stl/convert.hxx new file mode 100644 index 0000000..41ae270 --- /dev/null +++ b/bindings/compiler-stl/convert.hxx @@ -0,0 +1,12 @@ +#ifndef BINDINGS_COMPILER_STL_CONVERT_HXX +#define BINDINGS_COMPILER_STL_CONVERT_HXX + +#include + +namespace autocog::compiler::stl { + +// TODO IR convertion to Python (with Protobuf?) + +} + +#endif /* BINDINGS_COMPILER_STL_CONVERT_HXX */ diff --git a/bindings/compiler-stl/module.cxx b/bindings/compiler-stl/module.cxx new file mode 100644 index 0000000..11bdcf0 --- /dev/null +++ b/bindings/compiler-stl/module.cxx @@ -0,0 +1,15 @@ + +#include "convert.hxx" + +#include "autocog/compiler/stl/parser.hxx" +#include "autocog/compiler/stl/ast.hxx" + +#include +#include + +PYBIND11_MODULE(stl_cxx, module) { + module.doc() = "C++ STL parser for AutoCog"; + + // TODO +} + diff --git a/bindings/llama-xfta/CMakeLists.txt b/bindings/llama-xfta/CMakeLists.txt new file mode 100644 index 0000000..dbccff2 --- /dev/null +++ b/bindings/llama-xfta/CMakeLists.txt @@ -0,0 +1,31 @@ + +################################################################################################### +# Binding autocog::llama::xfta to autocog.llama.xfta + +pybind11_add_module(autocog_llama_xfta_pybind + module.cxx + convert.cxx +) +target_include_directories(autocog_llama_xfta_pybind PRIVATE + ${PROJECT_SOURCE_DIR}/libs + ${PROJECT_SOURCE_DIR}/vendors/headers +) +target_link_libraries(autocog_llama_xfta_pybind PRIVATE + autocog_llama_xfta_lib + autocog_utilities_lib +) +set_property(TARGET autocog_llama_xfta_pybind PROPERTY 
POSITION_INDEPENDENT_CODE ON) +set_target_properties(autocog_llama_xfta_pybind PROPERTIES + OUTPUT_NAME "xfta_cxx" + PREFIX "${PYTHON_MODULE_PREFIX}" + SUFFIX "${PYTHON_MODULE_EXTENSION}" +) +install(TARGETS autocog_llama_xfta_pybind DESTINATION autocog/llama/xfta) + +add_test(NAME python_xfta_binding_smoke_test + COMMAND ${Python_EXECUTABLE} -c "import autocog.llama.xfta" + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) + +set_tests_properties(python_xfta_binding_smoke_test PROPERTIES + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}:$ENV{PYTHONPATH}") + diff --git a/bindings/llama-xfta/convert.cxx b/bindings/llama-xfta/convert.cxx new file mode 100644 index 0000000..7da0bf5 --- /dev/null +++ b/bindings/llama-xfta/convert.cxx @@ -0,0 +1,170 @@ + +#include "convert.hxx" + +#include "autocog/llama/xfta/manager.hxx" + +#include + +namespace autocog::llama::xfta { + +FTA convert_pydict_to_fta(ModelID const id, pybind11::dict const & pydata) { + Model & model = Manager::get_model(id); + FTA fta; + + // Get the actions dictionary from Python FTA + if (!pydata.contains("actions")) { + throw std::runtime_error("FTA dictionary missing 'actions' field"); + } + + auto py_actions = pydata["actions"].cast(); + std::map uid_to_id; + + for (auto item : py_actions) { + auto action_dict = item.cast(); + + if (!action_dict.contains("uid")) { + throw std::runtime_error("Action missing 'uid' field."); + } + auto uid = action_dict["uid"].cast(); + if (!action_dict.contains("__type__")) { + throw std::runtime_error("Action missing '__type__' field: " + uid); + } + + std::string action_type = action_dict["__type__"].cast(); + + ActionID node_id = fta.actions.size(); // Assign sequential IDs + uid_to_id[uid] = node_id; + + std::unique_ptr action; + + if (action_type == "Text") { + bool evaluate = false; + if (action_dict.contains("evaluate")) evaluate = action_dict["evaluate"].cast(); + + action = std::make_unique(node_id, uid, evaluate); + Text * text_action = 
static_cast(action.get()); + + auto py_tokens = action_dict["tokens"].cast(); + text_action->tokens.clear(); + for (auto token : py_tokens) { + text_action->tokens.push_back(token.cast()); + } + + } else if (action_type == "Complete") { + float threshold = action_dict["threshold"].cast(); + unsigned length = action_dict["length"].cast(); + unsigned beams = action_dict["beams"].cast(); + unsigned ahead = action_dict["ahead"].cast(); + unsigned width = action_dict["width"].cast(); + std::optional repetition = std::nullopt; + if (action_dict.contains("repetition") && !action_dict["repetition"].is_none()) repetition = action_dict["repetition"].cast(); + std::optional diversity = std::nullopt; + if (action_dict.contains("diversity") && !action_dict["diversity"].is_none()) diversity = action_dict["diversity"].cast(); + + action = std::make_unique(node_id, uid, threshold, length, beams, ahead, width, repetition, diversity); + Completion* completion_action = static_cast(action.get()); + + if (!action_dict.contains("stop")) { + throw std::runtime_error("Completion missing 'stop' field: " + uid); + } else { + auto py_stop = action_dict["stop"].cast(); + completion_action->stop.clear(); + for (auto token : py_stop) { + completion_action->stop.push_back(token.cast()); + } + } + + completion_action->vocab.mask.clear(); + completion_action->vocab.mask.reserve(model.vocab_size()); + for (auto tok_msk : action_dict["vocab"]) { + completion_action->vocab.mask.push_back(tok_msk.cast()); + } + + } else if (action_type == "Choose") { + float threshold = action_dict["threshold"].cast(); + unsigned width = action_dict["width"].cast(); + action = std::make_unique(node_id, uid, threshold, width); + Choice* choice_action = static_cast(action.get()); + + if (action_dict.contains("choices")) { + auto py_choices = action_dict["choices"].cast(); + choice_action->choices.clear(); + choice_action->choices.reserve(py_choices.size()); + + for (auto py_choice : py_choices) { + auto py_tokens = 
py_choice.cast(); + TokenSequence choice_tokens; + choice_tokens.reserve(py_tokens.size()); + for (auto token : py_tokens) { + choice_tokens.push_back(token.cast()); + } + choice_action->choices.push_back(std::move(choice_tokens)); + } + } + + } else { + throw std::runtime_error("Unknown action type: " + action_type); + } + + fta.actions.push_back(std::move(action)); + } + + for (auto item : py_actions) { + auto action_dict = item.cast(); + auto uid = action_dict["uid"].cast(); + + ActionID node_id = uid_to_id[uid]; + Action * action = fta.actions[node_id].get(); + + // Add successors + if (action_dict.contains("successors")) { + auto py_successors = action_dict["successors"].cast(); + for (auto successor : py_successors) { + std::string successor_uid = successor.cast(); + if (uid_to_id.find(successor_uid) == uid_to_id.end()) { + throw std::runtime_error("Unknown successor UID: " + successor_uid); + } + action->successors.push_back(uid_to_id[successor_uid]); + } + } + } + + return fta; +} + +pybind11::dict convert_ftt_to_pydict(ModelID const id, FTT const & ftt) { + [[maybe_unused]] Model & model = Manager::get_model(id); + + pybind11::dict result; + result["action"] = ftt.action; + + // Convert tokens + pybind11::list token_list; + for (TokenID token : ftt.tokens) { + token_list.append(token); + } + result["tokens"] = token_list; + + // Convert probabilities + pybind11::list logprobs_list; + for (float lpb : ftt.logprobs) { + logprobs_list.append(lpb); + } + result["logprobs"] = logprobs_list; + result["logprob"] = ftt.logprob; + result["length"] = ftt.length; + + // Convert children recursively + pybind11::list children_list; + for (const FTT& child : ftt.get_children()) { + children_list.append(convert_ftt_to_pydict(id, child)); + } + result["children"] = children_list; + + // Add metadata + result["pruned"] = ftt.pruned; + + return result; +} + +} diff --git a/bindings/llama-xfta/convert.hxx b/bindings/llama-xfta/convert.hxx new file mode 100644 index 
0000000..335d762 --- /dev/null +++ b/bindings/llama-xfta/convert.hxx @@ -0,0 +1,16 @@ +#ifndef BINDINGS_LLAMA_XFTA_CONVERT_HXX +#define BINDINGS_LLAMA_XFTA_CONVERT_HXX + +#include "autocog/llama/xfta/fta.hxx" +#include "autocog/llama/xfta/ftt.hxx" + +#include + +namespace autocog::llama::xfta { + +FTA convert_pydict_to_fta(ModelID const id, pybind11::dict const & pydata); +pybind11::dict convert_ftt_to_pydict(ModelID const id, FTT const & ftt); + +} + +#endif /* BINDINGS_LLAMA_XFTA_CONVERT_HXX */ diff --git a/bindings/llama-xfta/module.cxx b/bindings/llama-xfta/module.cxx new file mode 100644 index 0000000..ce120c2 --- /dev/null +++ b/bindings/llama-xfta/module.cxx @@ -0,0 +1,154 @@ + +#include "convert.hxx" + +#include "autocog/llama/xfta/model.hxx" +#include "autocog/llama/xfta/evaluation.hxx" +#include "autocog/llama/xfta/manager.hxx" + +#include +#include +#include + +#include + +#if VERBOSE +# include +#endif +#define DEBUG_pybind_evaluate VERBOSE && 0 + +PYBIND11_MODULE(xfta_cxx, module) { + using namespace autocog::llama::xfta; + + Manager::initialize(); + + module.doc() = "AutoCog's llama.cpp integration module"; + + module.def("create", + [](std::string const & model_path, int n_ctx) { + return Manager::add_model(model_path, n_ctx); + }, + "Instantiate a GGML model with llama.cpp", + pybind11::arg("model_path"), + pybind11::arg("n_ctx") = 4096 + ); + + module.def("vocab_size", [](ModelID model) { + return Manager::get_model(model).vocab_size(); + }, "Get vocabulary size"); + + module.def("tokenize", + [](ModelID model, const std::string & text, bool add_bos, bool special) { + auto tokens = Manager::get_model(model).tokenize(text, add_bos, special); + pybind11::list result; + for (auto token : tokens) result.append(token); + return result; + }, + "Tokenize text using llama.cpp", + pybind11::arg("model"), + pybind11::arg("text"), + pybind11::arg("add_bos") = false, + pybind11::arg("special") = false + ); + + module.def("detokenize", + [](ModelID model, const 
pybind11::list & py_tokens, bool spec_rm, bool spec_unp) { + TokenSequence tokens; + for (auto item : py_tokens) tokens.push_back(item.cast()); + return Manager::get_model(model).detokenize(tokens, spec_rm, spec_unp); + }, + "Detokenize tokens to text", + pybind11::arg("model"), + pybind11::arg("tokens"), + pybind11::arg("spec_rm") = false, + pybind11::arg("spec_unp") = false + ); + + module.def("evaluate", + [](ModelID model, pybind11::dict const & fta_dict) { +#if DEBUG_pybind_evaluate + std::cerr << "IN evaluate (pybind): START" << std::endl; +#endif + FTA fta = convert_pydict_to_fta(model, fta_dict); +#if DEBUG_pybind_evaluate + std::cerr << "IN evaluate (pybind): FTA" << std::endl; +#endif + EvalID eval = Manager::add_eval(model, fta); +#if DEBUG_pybind_evaluate + std::cerr << "IN evaluate (pybind): EVAL" << std::endl; +#endif + Manager::advance(eval, std::nullopt); +#if DEBUG_pybind_evaluate + std::cerr << "IN evaluate (pybind): FTT" << std::endl; +#endif + FTT const & ftt = Manager::retrieve(eval); +#if DEBUG_pybind_evaluate + std::cerr << "IN evaluate (pybind): RES" << std::endl; +#endif + pybind11::dict res = convert_ftt_to_pydict(model, ftt); +#if DEBUG_pybind_evaluate + std::cerr << "IN evaluate (pybind): CLEAN" << std::endl; +#endif + Manager::rm_eval(eval); +#if DEBUG_pybind_evaluate + std::cerr << "IN evaluate (pybind): DONE" << std::endl; +#endif + return res; + }, + "Evaluate a FTA using a model and return the FTT." 
+ ); + +#ifdef ASYNC_EXEC + module.def("instantiate", + [](ModelID model, pybind11::dict const & fta) { + FTA fta_ = convert_pydict_to_fta(model, fta); + // TODO + }, + "Instantiate a FTA using a model and return the EvalID.", + pybind11::arg("model"), + pybind11::arg("fta") + ); + + module.def("advance", + [](EvalID eval, std::optional max_token_eval) { + // TODO + }, + "Advance a FTA evaluation with an optional (soft) limit on the number of token evaluation, return the number of evaluated tokens.", + pybind11::arg("eval"), + pybind11::arg("max_token_eval") = std::nullopt + ); + + module.def("advance_bg", + [](EvalID eval, std::optional max_token_eval) { + // TODO + }, + "Signal an evaluation to run in the background, return nothing immediately.", + pybind11::arg("eval"), + pybind11::arg("max_token_eval") = std::nullopt + ); + + module.def("finished", + [](EvalID eval) { + // TODO + }, + "Check if a FTA evaluation is finished.", + pybind11::arg("eval") + ); + + module.def("retrieve", + [](EvalID eval) { + // TODO + }, + "Retrieve the FTT being generated by an evaluation.", + pybind11::arg("eval") + ); + + module.def("release", + [](EvalID eval) { + // TODO + }, + "Wait for the evaluation to finish, retrieve the generated FTT, then remove the evaluation.", + pybind11::arg("eval") + ); +#endif /* ASYNC_EXEC */ +} + diff --git a/libs/autocog/compiler/stl/CMakeLists.txt b/libs/autocog/compiler/stl/CMakeLists.txt new file mode 100644 index 0000000..2a921d4 --- /dev/null +++ b/libs/autocog/compiler/stl/CMakeLists.txt @@ -0,0 +1,41 @@ + +# Generate lexer from .l file +add_custom_command( + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/autocog_compiler_stl_lexer.cxx ${CMAKE_CURRENT_BINARY_DIR}/autocog_compiler_stl_lexer.hxx + COMMAND ${REFLEX} --header-file=${CMAKE_CURRENT_BINARY_DIR}/autocog_compiler_stl_lexer.hxx + --outfile=${CMAKE_CURRENT_BINARY_DIR}/autocog_compiler_stl_lexer.cxx + ${CMAKE_CURRENT_SOURCE_DIR}/lexer.l + DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/lexer.l + COMMENT 
"Generating lexer with RE/flex" +) + +add_library(autocog_compiler_stl_lib STATIC + driver.cxx diagnostic.cxx + token.cxx ast.cxx ast/expr.cxx ast/tostring.cxx + parser.cxx parser-state.cxx + parser/annotate.cxx parser/call.cxx parser/channel.cxx + parser/define.cxx parser/alias.cxx parser/expression.cxx + parser/field.cxx parser/fieldref.cxx parser/flow.cxx + parser/format.cxx parser/import.cxx parser/kwarg.cxx + parser/link.cxx parser/path.cxx parser/program.cxx + parser/prompt.cxx parser/objectref.cxx parser/record.cxx + parser/return.cxx parser/search.cxx parser/clauses.cxx + parser/struct.cxx parser/assign.cxx + symbols.cxx symbol-table.cxx symbol-scanner.cxx + evaluate.cxx instance-scanner.cxx + ${CMAKE_CURRENT_BINARY_DIR}/autocog_compiler_stl_lexer.cxx +) + +target_include_directories(autocog_compiler_stl_lib PUBLIC + ${PROJECT_SOURCE_DIR}/libs + ${PROJECT_SOURCE_DIR}/vendors/headers + ${CMAKE_CURRENT_BINARY_DIR} + ${REFLEX_INC} +) + +target_link_libraries(autocog_compiler_stl_lib PUBLIC + ${REFLEX_LIB} +) + +set_property(TARGET autocog_compiler_stl_lib PROPERTY POSITION_INDEPENDENT_CODE ON) + diff --git a/libs/autocog/compiler/stl/ast.cxx b/libs/autocog/compiler/stl/ast.cxx new file mode 100644 index 0000000..1257706 --- /dev/null +++ b/libs/autocog/compiler/stl/ast.cxx @@ -0,0 +1,43 @@ + +#include "autocog/compiler/stl/ast.hxx" + +#include +#include + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl::ast { + +const std::unordered_map tags = { +#define X(etag,stag) {stag,Tag::etag}, +#include "autocog/compiler/stl/ast/nodes.def" +}; +const std::unordered_map rtags = { +#define X(etag,stag) {Tag::etag,stag}, +#include "autocog/compiler/stl/ast/nodes.def" +}; + +std::string tag2str(Tag tag) { + auto it = rtags.find(tag); + if (it == rtags.end()) { + std::ostringstream oss; + oss << "Unknown ast::Tag for tag2str(" << (long)tag << ")!"; + throw std::runtime_error(oss.str()); + } + return it->second; +} + +Tag str2tag(std::string str) { + auto 
it = tags.find(str); + if (it == tags.end()) { + std::ostringstream oss; + oss << "Unrecognized ast::Tag for str2tag(" << str << ")!"; + throw std::runtime_error(oss.str()); + } + return it->second; +} + +} + diff --git a/libs/autocog/compiler/stl/ast.hxx b/libs/autocog/compiler/stl/ast.hxx new file mode 100644 index 0000000..ee4c1ac --- /dev/null +++ b/libs/autocog/compiler/stl/ast.hxx @@ -0,0 +1,145 @@ +#ifndef AUTOCOG_COMPILER_STL_AST_HXX +#define AUTOCOG_COMPILER_STL_AST_HXX + +#include "autocog/compiler/stl/token.hxx" +#include "autocog/compiler/stl/location.hxx" + +#include +#include +#include +#include +#include +#include +#include + +namespace autocog::compiler::stl { + +class Lexer; +class Parser; +struct Diagnostic; + +} + +namespace autocog::compiler::stl::ast { + +enum class Tag { +#define X(etag,stag) etag, +#include "autocog/compiler/stl/ast/nodes.def" +}; + +extern const std::unordered_map tags; +extern const std::unordered_map rtags; + +std::string tag2str(Tag tag); +Tag str2tag(std::string str); + +template +struct Data; + +template +struct Node { + static constexpr Tag tag = tagT; + + Node() : data() {} + + template + Node(Args&&... args) : data{std::forward(args)...} {} + + template + static std::unique_ptr make(Args&&... 
args) { + return std::make_unique(std::forward(args)...); + } + + Node(const Node&) = delete; + Node & operator=(const Node &) = delete; + Node(Node &&) = delete; + Node & operator=(Node &&) = delete; + + std::optional location; + Data data; + + template + void traverse(TraversalT & traversal) const { + traversal.pre(*this); + if (!traversal.shortcut(*this)) + traverse_children(traversal); + traversal.post(*this); + } + + template + void traverse_children(TraversalT & traversal) const; +}; + +} + +#include "autocog/compiler/stl/ast/traits.hxx" + +namespace autocog::compiler::stl::ast { + +template +std::enable_if_t> traverse_generic( + TraversalT & traversal, T const & container +) { + if constexpr (is_node_v) { + container.traverse(traversal); + } else if constexpr (is_pnode_v) { + if (container) container->traverse(traversal); + } else if constexpr (is_onode_v) { + if (container) container->traverse(traversal); + } else if constexpr (is_nodes_v) { + for (auto const & node : container) + node.traverse(traversal); + } else if constexpr (is_pnodes_v) { + for (auto const & pnode : container) + if (pnode) + pnode->traverse(traversal); + } else if constexpr (is_variant_v) { + std::visit([&traversal](auto const & node) { + if constexpr (is_node_v>) { + node.traverse(traversal); + } + }, container); + } else if constexpr (is_variants_v) { + for (auto const & vnode : container) + std::visit([&traversal](auto const & node) { + if constexpr (is_node_v>) { + node.traverse(traversal); + } + }, vnode); + } +} + +} + +#include "autocog/compiler/stl/ast/macros.hxx" + +#include "autocog/compiler/stl/ast/expr.hxx" +#include "autocog/compiler/stl/ast/path.hxx" +#include "autocog/compiler/stl/ast/channel.hxx" +#include "autocog/compiler/stl/ast/return.hxx" +#include "autocog/compiler/stl/ast/search.hxx" +#include "autocog/compiler/stl/ast/struct.hxx" +#include "autocog/compiler/stl/ast/annot.hxx" +#include "autocog/compiler/stl/ast/flow.hxx" +#include 
"autocog/compiler/stl/ast/define.hxx" +#include "autocog/compiler/stl/ast/prompt.hxx" +#include "autocog/compiler/stl/ast/record.hxx" +#include "autocog/compiler/stl/ast/program.hxx" + +namespace autocog::compiler::stl::ast { + +#define X(etag,stag) using etag = NODE(etag); +#include "autocog/compiler/stl/ast/nodes.def" + +#define X(etag,stag) using etag##Ptr = const etag*; +#include "autocog/compiler/stl/ast/nodes.def" + +using NodePtr = std::variant< + std::monostate +#define X(etag,stag) , etag##Ptr +#include "autocog/compiler/stl/ast/nodes.def" +>; + +} + +#endif // AUTOCOG_COMPILER_STL_AST_HXX diff --git a/libs/autocog/compiler/stl/ast/annot.hxx b/libs/autocog/compiler/stl/ast/annot.hxx new file mode 100644 index 0000000..f7bb67b --- /dev/null +++ b/libs/autocog/compiler/stl/ast/annot.hxx @@ -0,0 +1,22 @@ +#ifndef AUTOCOG_COMPILER_STL_AST_ANNOT_HXX +#define AUTOCOG_COMPILER_STL_AST_ANNOT_HXX + +namespace autocog::compiler::stl::ast { + +DATA(Annotation) { + ONODE(Path) path; + NODE(Expression) description; +}; +TRAVERSE_CHILDREN(Annotation, path, description) + +DATA(Annotate) { + bool single_statement; //< Keyword followed by single element instead of block + NODES(Annotation) annotations; +}; +TRAVERSE_CHILDREN(Annotate, annotations) + + + +} + +#endif // AUTOCOG_COMPILER_STL_AST_ANNOT_HXX diff --git a/libs/autocog/compiler/stl/ast/channel.hxx b/libs/autocog/compiler/stl/ast/channel.hxx new file mode 100644 index 0000000..6827311 --- /dev/null +++ b/libs/autocog/compiler/stl/ast/channel.hxx @@ -0,0 +1,60 @@ +#ifndef AUTOCOG_COMPILER_STL_AST_CHANNEL_HXX +#define AUTOCOG_COMPILER_STL_AST_CHANNEL_HXX + +namespace autocog::compiler::stl::ast { + +DATA(Bind) { + NODE(Path) source; + ONODE(Path) target; // if missing then '_' +}; +TRAVERSE_CHILDREN(Bind, source, target) + +DATA(Ravel) { + ONODE(Expression) depth; // if missing then 1 + ONODE(Path) target; // if missing then '_' +}; +TRAVERSE_CHILDREN(Ravel, depth, target) + +DATA(Wrap) { + ONODE(Path) target; // 
if missing then '_' +}; +TRAVERSE_CHILDREN(Wrap, target) + +DATA(Prune) { + NODE(Path) target; +}; +TRAVERSE_CHILDREN(Prune, target) + +DATA(Mapped) { + ONODE(Path) target; // if missing then '_' +}; +TRAVERSE_CHILDREN(Mapped, target) + +DATA(Kwarg) { + NODE(Identifier) name; + VARIANT(FieldRef, Path, Expression) source; // "use" (a field), "get" (an input), "is" (a compile-time value), + VARIANTS(Bind, Ravel, Wrap, Prune, Mapped) clauses; +}; +TRAVERSE_CHILDREN(Kwarg, name, source, clauses) + +DATA(Call) { + NODE(ObjectRef) entry; //< Could resolve to a python symbol + NODES(Kwarg) arguments; +}; +TRAVERSE_CHILDREN(Call, entry, arguments) + +DATA(Link) { + NODE(Path) target; + VARIANT(FieldRef, Path, Expression, Call) source; // "use" (a field), "get" (an input), "is" (a compile-time value), or "call" (a callable) + VARIANTS(Bind, Ravel, Prune, Wrap) clauses; +}; +TRAVERSE_CHILDREN(Link, target, source, clauses) + +DATA(Channel) { + NODES(Link) links; +}; +TRAVERSE_CHILDREN(Channel, links) + +} + +#endif // AUTOCOG_COMPILER_STL_AST_CHANNEL_HXX diff --git a/libs/autocog/compiler/stl/ast/define.hxx b/libs/autocog/compiler/stl/ast/define.hxx new file mode 100644 index 0000000..b5def78 --- /dev/null +++ b/libs/autocog/compiler/stl/ast/define.hxx @@ -0,0 +1,15 @@ +#ifndef AUTOCOG_COMPILER_STL_AST_DEFINE_HXX +#define AUTOCOG_COMPILER_STL_AST_DEFINE_HXX + +namespace autocog::compiler::stl::ast { + +DATA(Define) { + bool is_argument; + NODE(Identifier) name; + ONODE(Expression) init; +}; +TRAVERSE_CHILDREN(Define, name, init) + +} + +#endif // AUTOCOG_COMPILER_STL_AST_DEFINE_HXX diff --git a/libs/autocog/compiler/stl/ast/expr.cxx b/libs/autocog/compiler/stl/ast/expr.cxx new file mode 100644 index 0000000..bb1d212 --- /dev/null +++ b/libs/autocog/compiler/stl/ast/expr.cxx @@ -0,0 +1,29 @@ + +#include "autocog/compiler/stl/ast.hxx" + +namespace autocog::compiler::stl::ast { + +std::string opKindToString(ast::OpKind op) { + switch (op) { + case ast::OpKind::NOP: return 
"NOP"; + case ast::OpKind::Not: return "Not"; + case ast::OpKind::Neg: return "Neg"; + case ast::OpKind::Add: return "Add"; + case ast::OpKind::Sub: return "Sub"; + case ast::OpKind::Mul: return "Mul"; + case ast::OpKind::Div: return "Div"; + case ast::OpKind::Mod: return "Mod"; + case ast::OpKind::And: return "And"; + case ast::OpKind::Or: return "Or"; + case ast::OpKind::Lt: return "Lt"; + case ast::OpKind::Gt: return "Gt"; + case ast::OpKind::Lte: return "Lte"; + case ast::OpKind::Gte: return "Gte"; + case ast::OpKind::Eq: return "Eq"; + case ast::OpKind::Neq: return "Neq"; + default: return "Unknown(" + std::to_string(static_cast(op)) + ")"; + } +} + +} + diff --git a/libs/autocog/compiler/stl/ast/expr.hxx b/libs/autocog/compiler/stl/ast/expr.hxx new file mode 100644 index 0000000..4a34840 --- /dev/null +++ b/libs/autocog/compiler/stl/ast/expr.hxx @@ -0,0 +1,73 @@ +#ifndef AUTOCOG_COMPILER_STL_AST_EXPR_HXX +#define AUTOCOG_COMPILER_STL_AST_EXPR_HXX + +namespace autocog::compiler::stl::ast { + +enum class OpKind { + NOP, // Not an operator + Not, Neg, // Unary + Add, Sub, Mul, Div, Mod, // Arithmetic + And, Or, // Logical + Lt, Gt, Lte, Gte, Eq, Neq // Comparison +}; +std::string opKindToString(ast::OpKind op); + +DATA(Identifier) { + std::string name; +}; +TRAVERSE_CHILDREN_EMPTY(Identifier) + +DATA(Integer) { + int value; +}; +TRAVERSE_CHILDREN_EMPTY(Integer) + +DATA(Float) { + float value; +}; +TRAVERSE_CHILDREN_EMPTY(Float) + +DATA(Boolean) { + bool value; +}; +TRAVERSE_CHILDREN_EMPTY(Boolean) + +DATA(String) { + std::string value; + bool is_format{false}; +}; +TRAVERSE_CHILDREN_EMPTY(String) + +DATA(Unary) { + OpKind kind; + PNODE(Expression) operand; +}; +TRAVERSE_CHILDREN(Unary, operand) + +DATA(Binary) { + OpKind kind; + PNODE(Expression) lhs; + PNODE(Expression) rhs; +}; +TRAVERSE_CHILDREN(Binary, lhs, rhs) + +DATA(Conditional) { + PNODE(Expression) cond; + PNODE(Expression) e_true; + PNODE(Expression) e_false; +}; +TRAVERSE_CHILDREN(Conditional, cond, 
e_true, e_false) + +DATA(Parenthesis) { + PNODE(Expression) expr; +}; +TRAVERSE_CHILDREN(Parenthesis, expr) + +DATA(Expression) { + VARIANT(Identifier, Integer, Float, Boolean, String, Unary, Binary, Conditional, Parenthesis) expr; +}; +TRAVERSE_CHILDREN(Expression, expr) + +} + +#endif // AUTOCOG_COMPILER_STL_AST_EXPR_HXX diff --git a/libs/autocog/compiler/stl/ast/flow.hxx b/libs/autocog/compiler/stl/ast/flow.hxx new file mode 100644 index 0000000..6317b3a --- /dev/null +++ b/libs/autocog/compiler/stl/ast/flow.hxx @@ -0,0 +1,31 @@ +#ifndef AUTOCOG_COMPILER_STL_AST_FLOW_HXX +#define AUTOCOG_COMPILER_STL_AST_FLOW_HXX + +namespace autocog::compiler::stl::ast { + +DATA(Edge) { + NODE(ObjectRef) prompt; + ONODE(Expression) limit; + ONODE(Expression) label; +}; +TRAVERSE_CHILDREN(Edge, prompt, limit, label) + +DATA(Flow) { + bool short_form; + NODES(Edge) edges; +}; +TRAVERSE_CHILDREN(Flow, edges) + +/** + * Short form: + * `flow my_prompt;` + * `flow my_prompt[3] as "XXX";` + * is equivalent to: + * `flow { my_prompt; }` + * `flow { my_prompt[3] as "XXX"; }` + * + */ + +} + +#endif // AUTOCOG_COMPILER_STL_AST_FLOW_HXX diff --git a/libs/autocog/compiler/stl/ast/macros.hxx b/libs/autocog/compiler/stl/ast/macros.hxx new file mode 100644 index 0000000..4a34a81 --- /dev/null +++ b/libs/autocog/compiler/stl/ast/macros.hxx @@ -0,0 +1,61 @@ + +#define EXPAND(...) __VA_ARGS__ +#define FOR_EACH_1(what, x) what(x) +#define FOR_EACH_2(what, x, ...) what(x), EXPAND(FOR_EACH_1(what, __VA_ARGS__)) +#define FOR_EACH_3(what, x, ...) what(x), EXPAND(FOR_EACH_2(what, __VA_ARGS__)) +#define FOR_EACH_4(what, x, ...) what(x), EXPAND(FOR_EACH_3(what, __VA_ARGS__)) +#define FOR_EACH_5(what, x, ...) what(x), EXPAND(FOR_EACH_4(what, __VA_ARGS__)) +#define FOR_EACH_6(what, x, ...) what(x), EXPAND(FOR_EACH_5(what, __VA_ARGS__)) +#define FOR_EACH_7(what, x, ...) what(x), EXPAND(FOR_EACH_6(what, __VA_ARGS__)) +#define FOR_EACH_8(what, x, ...) 
what(x), EXPAND(FOR_EACH_7(what, __VA_ARGS__)) +#define FOR_EACH_9(what, x, ...) what(x), EXPAND(FOR_EACH_8(what, __VA_ARGS__)) + +#define GET_MACRO(_1,_2,_3,_4,_5,_6,_7,_8,_9,NAME,...) NAME +#define FOR_EACH(action, ...) \ + GET_MACRO(__VA_ARGS__, FOR_EACH_9, FOR_EACH_8, FOR_EACH_7, FOR_EACH_6, FOR_EACH_5, \ + FOR_EACH_4, FOR_EACH_3, FOR_EACH_2, FOR_EACH_1)(action, __VA_ARGS__) + +#define TAG_PREFIX(x) Tag::x +#define FOR_EACH_TAG(...) FOR_EACH(TAG_PREFIX, __VA_ARGS__) + +#define NODE(tag) ::autocog::compiler::stl::ast::node_t +#define PNODE(tag) ::autocog::compiler::stl::ast::pnode_t +#define ONODE(tag) ::autocog::compiler::stl::ast::onode_t +#define NODES(tag) ::autocog::compiler::stl::ast::nodes_t +#define PNODES(tag) ::autocog::compiler::stl::ast::pnodes_t +#define VARIANT(...) ::autocog::compiler::stl::ast::variant_t +#define VARIANTS(...) ::autocog::compiler::stl::ast::variants_t + +#define DATA(tag) template <> struct Data< Tag::tag > + +#define EXPAND_STMT(...) __VA_ARGS__ +#define FOR_EACH_STMT_1(what, x) what(x); +#define FOR_EACH_STMT_2(what, x, ...) what(x); EXPAND_STMT(FOR_EACH_STMT_1(what, __VA_ARGS__)) +#define FOR_EACH_STMT_3(what, x, ...) what(x); EXPAND_STMT(FOR_EACH_STMT_2(what, __VA_ARGS__)) +#define FOR_EACH_STMT_4(what, x, ...) what(x); EXPAND_STMT(FOR_EACH_STMT_3(what, __VA_ARGS__)) +#define FOR_EACH_STMT_5(what, x, ...) what(x); EXPAND_STMT(FOR_EACH_STMT_4(what, __VA_ARGS__)) +#define FOR_EACH_STMT_6(what, x, ...) what(x); EXPAND_STMT(FOR_EACH_STMT_5(what, __VA_ARGS__)) +#define FOR_EACH_STMT_7(what, x, ...) what(x); EXPAND_STMT(FOR_EACH_STMT_6(what, __VA_ARGS__)) +#define FOR_EACH_STMT_8(what, x, ...) what(x); EXPAND_STMT(FOR_EACH_STMT_7(what, __VA_ARGS__)) +#define FOR_EACH_STMT_9(what, x, ...) what(x); EXPAND_STMT(FOR_EACH_STMT_8(what, __VA_ARGS__)) + +#define GET_MACRO_STMT(_1,_2,_3,_4,_5,_6,_7,_8,_9,NAME,...) NAME +#define FOR_EACH_STMT(action, ...) 
\ + GET_MACRO_STMT(__VA_ARGS__, FOR_EACH_STMT_9, FOR_EACH_STMT_8, FOR_EACH_STMT_7, \ + FOR_EACH_STMT_6, FOR_EACH_STMT_5, FOR_EACH_STMT_4, \ + FOR_EACH_STMT_3, FOR_EACH_STMT_2, FOR_EACH_STMT_1)(action, __VA_ARGS__) + +#define TRAVERSE_FIELD(field) traverse_generic(traversal, data.field); + +#define TRAVERSE_CHILDREN(tag, ...) \ + template <> \ + template \ + void Node::traverse_children(TraversalT & traversal) const { \ + FOR_EACH_STMT(TRAVERSE_FIELD, __VA_ARGS__) \ + } + +#define TRAVERSE_CHILDREN_EMPTY(tag) \ + template <> \ + template \ + void Node::traverse_children(TraversalT &) const {} + diff --git a/libs/autocog/compiler/stl/ast/nodes.def b/libs/autocog/compiler/stl/ast/nodes.def new file mode 100644 index 0000000..5af999d --- /dev/null +++ b/libs/autocog/compiler/stl/ast/nodes.def @@ -0,0 +1,50 @@ +#ifndef X +#error "X macro not defined" +#endif + +X(Program,"Program") +X(Import,"Import") +X(Enum,"Enum") +X(Choice,"Choice") +X(Annotate,"Annotate") +X(Annotation,"Annotation") +X(Define,"Define") +X(Flow,"Flow") +X(Edge,"Edge") +X(Struct,"Struct") +X(Field,"Field") +X(Search,"Search") +X(Param,"Param") +X(Retfield,"Retfield") +X(Return,"Return") +X(Record,"Record") +X(FormatRef,"FormatRef") +X(Text,"Text") +X(Prompt,"Prompt") +X(FieldRef,"FieldRef") +X(ObjectRef,"ObjectRef") +X(Channel,"Channel") +X(Link,"Link") +X(Bind,"Bind") +X(Ravel,"Ravel") +X(Mapped,"Mapped") +X(Prune,"Prune") +X(Wrap,"Wrap") +X(Call,"Call") +X(Kwarg,"Kwarg") +X(Path,"Path") +X(Step,"Step") +X(Expression,"Expression") +X(Identifier,"Identifier") +X(Integer,"Integer") +X(Float,"Float") +X(Boolean,"Boolean") +X(String,"String") +X(Unary,"Unary") +X(Binary,"Binary") +X(Conditional,"Conditional") +X(Parenthesis,"Parenthesis") +X(Assign,"Assign") +X(Alias,"Alias") + +#undef X diff --git a/libs/autocog/compiler/stl/ast/path.hxx b/libs/autocog/compiler/stl/ast/path.hxx new file mode 100644 index 0000000..600d869 --- /dev/null +++ b/libs/autocog/compiler/stl/ast/path.hxx @@ -0,0 +1,39 @@ 
+#ifndef AUTOCOG_COMPILER_STL_AST_PATH_HXX +#define AUTOCOG_COMPILER_STL_AST_PATH_HXX + +namespace autocog::compiler::stl::ast { + +DATA(Step) { + NODE(Identifier) field; + ONODE(Expression) lower; + ONODE(Expression) upper; + bool is_range; +}; +TRAVERSE_CHILDREN(Step, field, lower, upper) + +DATA(Path) { + NODES(Step) steps; +}; +TRAVERSE_CHILDREN(Path, steps) + +DATA(Assign) { + NODE(Identifier) argument; + NODE(Expression) value; +}; +TRAVERSE_CHILDREN(Assign, argument, value) + +DATA(ObjectRef) { + NODE(Identifier) name; + NODES(Assign) config; +}; +TRAVERSE_CHILDREN(ObjectRef, name, config) + +DATA(FieldRef) { + ONODE(ObjectRef) prompt; + NODE(Path) field; +}; +TRAVERSE_CHILDREN(FieldRef, prompt, field) + +} + +#endif // AUTOCOG_COMPILER_STL_AST_PATH_HXX diff --git a/libs/autocog/compiler/stl/ast/printer.hxx b/libs/autocog/compiler/stl/ast/printer.hxx new file mode 100644 index 0000000..dded724 --- /dev/null +++ b/libs/autocog/compiler/stl/ast/printer.hxx @@ -0,0 +1,44 @@ +#ifndef AUTOCOG_COMPILER_STL_AST_PRINTER_HXX +#define AUTOCOG_COMPILER_STL_AST_PRINTER_HXX + +#include "autocog/compiler/stl/ast.hxx" + +namespace autocog::compiler::stl::ast { + +struct TagTreeTraversal { + std::ostream & out; + std::string prefix; + std::string indent; + int depth; + + TagTreeTraversal(std::ostream & out_, std::string prefix_="", std::string indent_=" ") + : out(out_), prefix(prefix_), indent(indent_), depth(0) {} + + template + void pre(Node const &) { + out << prefix; + for (int i = 0; i < depth; ++i) out << indent; + out << tag2str(T) << "\n"; + ++depth; + } + + template + void post(Node const &) { + --depth; + } + + template + bool shortcut(Node const &) { + return false; + } +}; + +template +void printTagTree(Node const & node, std::ostream & out = std::cout, std::string prefix = "", std::string indent = " ") { + TagTreeTraversal traversal(out, prefix, indent); + node.traverse(traversal); +} + +} + +#endif /* AUTOCOG_COMPILER_STL_AST_PRINTER_HXX */ diff --git 
a/libs/autocog/compiler/stl/ast/program.hxx b/libs/autocog/compiler/stl/ast/program.hxx new file mode 100644 index 0000000..164e577 --- /dev/null +++ b/libs/autocog/compiler/stl/ast/program.hxx @@ -0,0 +1,28 @@ +#ifndef AUTOCOG_COMPILER_STL_AST_PROGRAM_HXX +#define AUTOCOG_COMPILER_STL_AST_PROGRAM_HXX + +namespace autocog::compiler::stl::ast { + +DATA(Alias) { + NODE(ObjectRef) target; + ONODE(Identifier) alias; + bool is_export; +}; +TRAVERSE_CHILDREN(Alias, target, alias) + +DATA(Import) { + std::string file; + NODES(Alias) targets; +}; +TRAVERSE_CHILDREN(Import, targets) + +DATA(Program) { + std::string filename; + int fid; + VARIANTS(Import, Alias, Define, Annotate, Search, Record, Prompt) statements{}; +}; +TRAVERSE_CHILDREN(Program, statements) + +} + +#endif // AUTOCOG_COMPILER_STL_AST_PROGRAM_HXX diff --git a/libs/autocog/compiler/stl/ast/prompt.hxx b/libs/autocog/compiler/stl/ast/prompt.hxx new file mode 100644 index 0000000..eac487e --- /dev/null +++ b/libs/autocog/compiler/stl/ast/prompt.hxx @@ -0,0 +1,15 @@ +#ifndef AUTOCOG_COMPILER_STL_AST_PROMPT_HXX +#define AUTOCOG_COMPILER_STL_AST_PROMPT_HXX + +namespace autocog::compiler::stl::ast { + +DATA(Prompt) { + NODE(Identifier) name; + ONODE(Struct) fields; + VARIANTS(Define, Annotate, Search, Alias, Channel, Flow, Return) constructs; +}; +TRAVERSE_CHILDREN(Prompt, name, fields, constructs) + +} + +#endif // AUTOCOG_COMPILER_STL_AST_PROMPT_HXX diff --git a/libs/autocog/compiler/stl/ast/record.hxx b/libs/autocog/compiler/stl/ast/record.hxx new file mode 100644 index 0000000..eb96528 --- /dev/null +++ b/libs/autocog/compiler/stl/ast/record.hxx @@ -0,0 +1,15 @@ +#ifndef AUTOCOG_COMPILER_STL_AST_RECORD_HXX +#define AUTOCOG_COMPILER_STL_AST_RECORD_HXX + +namespace autocog::compiler::stl::ast { + +DATA(Record) { + NODE(Identifier) name; + VARIANT(Struct, FormatRef) record; + VARIANTS(Define, Annotate, Search, Alias) constructs; +}; +TRAVERSE_CHILDREN(Record, name, record, constructs) + +} + +#endif // 
AUTOCOG_COMPILER_STL_AST_RECORD_HXX diff --git a/libs/autocog/compiler/stl/ast/return.hxx b/libs/autocog/compiler/stl/ast/return.hxx new file mode 100644 index 0000000..8da43d7 --- /dev/null +++ b/libs/autocog/compiler/stl/ast/return.hxx @@ -0,0 +1,41 @@ +#ifndef AUTOCOG_COMPILER_STL_AST_RETURN_HXX +#define AUTOCOG_COMPILER_STL_AST_RETURN_HXX + +namespace autocog::compiler::stl::ast { + +DATA(Retfield) { + ONODE(Expression) alias; // if missing implies `field.steps[-1].field` (meaning "C" for path "A[0:2].B.C[3]") + VARIANT(Path, Expression) source; + VARIANTS(Bind, Ravel, Wrap, Prune) clauses; +}; +TRAVERSE_CHILDREN(Retfield, alias, source, clauses) + +DATA(Return) { + ONODE(Expression) label; + bool short_form; + NODES(Retfield) fields; +}; +TRAVERSE_CHILDREN(Return, label, fields) + +/** + * Block form does not accept "_" as a field alias + * + * Short form: + * `return use field.field;` + * `return "label" use field.field;` + * is equivalent to: + * `return { "_" use field.field; }` + * `return "label" { "_" use field.field; }` + * + * Empty return is also possible: + * `return;` + * `return "label";` + * and is equivalent to: + * `return {}` + * `return "label" {}` + */ + + +} + +#endif // AUTOCOG_COMPILER_STL_AST_RETURN_HXX diff --git a/libs/autocog/compiler/stl/ast/search.hxx b/libs/autocog/compiler/stl/ast/search.hxx new file mode 100644 index 0000000..ddee999 --- /dev/null +++ b/libs/autocog/compiler/stl/ast/search.hxx @@ -0,0 +1,19 @@ +#ifndef AUTOCOG_COMPILER_STL_AST_SEARCH_HXX +#define AUTOCOG_COMPILER_STL_AST_SEARCH_HXX + +namespace autocog::compiler::stl::ast { + +DATA(Param) { + NODES(Identifier) locator; + NODE(Expression) value; +}; +TRAVERSE_CHILDREN(Param, locator, value) + +DATA(Search) { + NODES(Param) params; +}; +TRAVERSE_CHILDREN(Search, params) + +} + +#endif // AUTOCOG_COMPILER_STL_AST_SEARCH_HXX diff --git a/libs/autocog/compiler/stl/ast/struct.hxx b/libs/autocog/compiler/stl/ast/struct.hxx new file mode 100644 index 0000000..88ee718 --- 
/dev/null +++ b/libs/autocog/compiler/stl/ast/struct.hxx @@ -0,0 +1,46 @@ +#ifndef AUTOCOG_COMPILER_STL_AST_STRUCT_HXX +#define AUTOCOG_COMPILER_STL_AST_STRUCT_HXX + +namespace autocog::compiler::stl::ast { + +DATA(Enum) { + NODES(String) enumerators; +}; +TRAVERSE_CHILDREN(Enum, enumerators) + +enum class ChoiceKind { Repeat, Select }; + +DATA(Choice) { + ChoiceKind mode; + NODE(Path) source; +}; +TRAVERSE_CHILDREN(Choice, source) + +DATA(Text) { + // TODO vocab definition +}; +TRAVERSE_CHILDREN_EMPTY(Text) + +DATA(FormatRef) { + VARIANT(Identifier, Text, Enum, Choice) type; + NODES(Expression) args; + NODES(Assign) kwargs; +}; +TRAVERSE_CHILDREN(FormatRef, type, args, kwargs) + +DATA(Struct) { + PNODES(Field) fields; +}; +TRAVERSE_CHILDREN(Struct, fields) + +DATA(Field) { + NODE(Identifier) name; + ONODE(Expression) lower; + ONODE(Expression) upper; + VARIANT(FormatRef, Struct) type; +}; +TRAVERSE_CHILDREN(Field, name, lower, upper, type) + +} + +#endif // AUTOCOG_COMPILER_STL_AST_STRUCT_HXX diff --git a/libs/autocog/compiler/stl/ast/tostring.cxx b/libs/autocog/compiler/stl/ast/tostring.cxx new file mode 100644 index 0000000..494e097 --- /dev/null +++ b/libs/autocog/compiler/stl/ast/tostring.cxx @@ -0,0 +1,191 @@ + +#include "autocog/compiler/stl/ast/tostring.hxx" + +namespace autocog::compiler::stl::ast { + +std::string toString(std::monostate const &) { + return "null"; +} + +std::string toString(Identifier const & node) { + return node.data.name; +} + +std::string toString(Integer const & node) { + return std::to_string(node.data.value); +} + +std::string toString(Float const & node) { + return std::to_string(node.data.value); +} + +std::string toString(Boolean const & node) { + return node.data.value ? 
"true" : "false"; +} + +std::string toString(String const & node) { + std::stringstream ss; + if (node.data.is_format) { + ss << "f\"" << node.data.value << "\""; + } else { + ss << "\"" << node.data.value << "\""; + } + return ss.str(); +} + +std::string toString(Unary const & node) { + std::stringstream ss; + ss << "(" << opKindToString(node.data.kind) << " " << toString_generic(node.data.operand) << ")"; + return ss.str(); +} + +std::string toString(Binary const & node) { + std::stringstream ss; + ss << "(" << toString_generic(node.data.lhs) << " " << opKindToString(node.data.kind) << " " << toString_generic(node.data.rhs) << ")"; + return ss.str(); +} + +std::string toString(Conditional const & node) { + std::stringstream ss; + ss << "(" << toString_generic(node.data.cond) << " ? " + << toString_generic(node.data.e_true) << " : " + << toString_generic(node.data.e_false) << ")"; + return ss.str(); +} + +std::string toString(Parenthesis const & node) { + std::stringstream ss; + ss << "(" << toString_generic(node.data.expr) << ")"; + return ss.str(); +} + +std::string toString(Expression const & node) { + return toString_generic(node.data.expr); +} + +std::string toString(Assign const & node) { + std::stringstream ss; + ss << toString(node.data.argument) << "=" << toString(node.data.value); + return ss.str(); +} + +std::string toString(ObjectRef const & node) { + std::stringstream ss; + ss << "ObjectRef{" << toString(node.data.name); + if (!node.data.config.empty()) { + ss << ", config=" << toString_generic(node.data.config); + } + ss << "}"; + return ss.str(); +} + +std::string toString(Text const &) { + return "Text{}"; +} + +std::string toString(Enum const & node) { + std::stringstream ss; + ss << "Enum" << toString_generic(node.data.enumerators); + return ss.str(); +} + +std::string toString(Choice const & node) { + std::stringstream ss; + ss << "Choice{"; + ss << (node.data.mode == ChoiceKind::Repeat ? 
"Repeat" : "Select"); + ss << ", " << toString_generic(node.data.source) << "}"; + return ss.str(); +} + +std::string toString(FormatRef const & node) { + std::stringstream ss; + ss << "FormatRef{type=" << toString_generic(node.data.type); + if (!node.data.args.empty()) { + ss << ", args=" << toString_generic(node.data.args); + } + if (!node.data.kwargs.empty()) { + ss << ", kwargs=" << toString_generic(node.data.kwargs); + } + ss << "}"; + return ss.str(); +} + + +std::string toString(Path const &) { + return "Path{...}"; +} + +// refString implementations for ObjectRef and FormatRef +// These provide a more concise, reference-like syntax for debugging + +std::string refString(ObjectRef const & node) { + std::stringstream ss; + ss << toString(node.data.name); + + if (!node.data.config.empty()) { + ss << "("; + bool first = true; + for (auto const & assign : node.data.config) { + if (!first) ss << ", "; + ss << toString(assign.data.argument) << "=" << toString(assign.data.value); + first = false; + } + ss << ")"; + } + + return ss.str(); +} + +std::string refString(FormatRef const & node) { + std::stringstream ss; + + // Type name + std::visit([&ss](auto const & type) { + if constexpr (std::is_same_v, Identifier>) { + ss << type.data.name; + } else if constexpr (std::is_same_v, Text>) { + ss << "Text"; + } else if constexpr (std::is_same_v, Enum>) { + ss << "Enum"; + if (!type.data.enumerators.empty()) { + ss << "{"; + bool first = true; + for (auto const & e : type.data.enumerators) { + if (!first) ss << "|"; + ss << e.data.value; + first = false; + } + ss << "}"; + } + } else if constexpr (std::is_same_v, Choice>) { + ss << "Choice<" << (type.data.mode == ChoiceKind::Repeat ? 
"Repeat" : "Select") << ">"; + } + }, node.data.type); + + // Arguments and keyword arguments + if (!node.data.args.empty() || !node.data.kwargs.empty()) { + ss << "("; + bool first = true; + + // Positional arguments + for (auto const & arg : node.data.args) { + if (!first) ss << ", "; + ss << toString(arg); + first = false; + } + + // Keyword arguments + for (auto const & kwarg : node.data.kwargs) { + if (!first) ss << ", "; + ss << toString(kwarg.data.argument) << "=" << toString(kwarg.data.value); + first = false; + } + + ss << ")"; + } + + return ss.str(); +} + +} + diff --git a/libs/autocog/compiler/stl/ast/tostring.hxx b/libs/autocog/compiler/stl/ast/tostring.hxx new file mode 100644 index 0000000..1d24686 --- /dev/null +++ b/libs/autocog/compiler/stl/ast/tostring.hxx @@ -0,0 +1,90 @@ +#ifndef AUTOCOG_COMPILER_STL_AST_TOSTRING_HXX +#define AUTOCOG_COMPILER_STL_AST_TOSTRING_HXX + +#include "autocog/compiler/stl/ast.hxx" + +#include +#include +#include + +namespace autocog::compiler::stl::ast { + +std::string toString(std::monostate const &); +std::string toString(Identifier const &); +std::string toString(Integer const &); +std::string toString(Float const &); +std::string toString(Boolean const &); +std::string toString(String const &); +std::string toString(Unary const &); +std::string toString(Binary const &); +std::string toString(Conditional const &); +std::string toString(Parenthesis const &); +std::string toString(Expression const &); +std::string toString(Assign const &); +std::string toString(ObjectRef const &); +std::string toString(Text const &); +std::string toString(Enum const &); +std::string toString(Choice const &); +std::string toString(FormatRef const &); +std::string toString(Path const &); + +std::string refString(ObjectRef const &); +std::string refString(FormatRef const &); + +template +std::enable_if_t, std::string> toString_generic(T const & container) { + if constexpr (is_node_v) { + return toString(container); + } else if constexpr 
(is_pnode_v) { + if (container) return toString(*container); + return "null"; + } else if constexpr (is_onode_v) { + if (container) return toString(*container); + return "null"; + } else if constexpr (is_nodes_v) { + std::stringstream ss; + ss << "["; + bool first = true; + for (auto const & node : container) { + if (!first) ss << ", "; + ss << toString(node); + first = false; + } + ss << "]"; + return ss.str(); + } else if constexpr (is_pnodes_v) { + std::stringstream ss; + ss << "["; + bool first = true; + for (auto const & pnode : container) { + if (!first) ss << ", "; + if (pnode) ss << toString(*pnode); + else ss << "null"; + first = false; + } + ss << "]"; + return ss.str(); + } else if constexpr (is_variant_v) { + return std::visit([](auto const & node) { + return toString(node); + }, container); + } else if constexpr (is_variants_v) { + std::stringstream ss; + ss << "["; + bool first = true; + for (auto const & vnode : container) { + if (!first) ss << ", "; + ss << std::visit([](auto const & node) { + return toString(node); + }, vnode); + first = false; + } + ss << "]"; + return ss.str(); + } +} + +} + +#endif /* AUTOCOG_COMPILER_STL_AST_TOSTRING_HXX */ + diff --git a/libs/autocog/compiler/stl/ast/traits.hxx b/libs/autocog/compiler/stl/ast/traits.hxx new file mode 100644 index 0000000..7f19751 --- /dev/null +++ b/libs/autocog/compiler/stl/ast/traits.hxx @@ -0,0 +1,177 @@ +#ifndef AUTOCOG_COMPILER_STL_AST_TRAITS_HXX +#define AUTOCOG_COMPILER_STL_AST_TRAITS_HXX + +namespace autocog::compiler::stl::ast { + +// Forward declaration +enum class Tag; +template struct Node; + +// ============================================================================ +// Type aliases for all node container types +// ============================================================================ + +// Basic node type +template +using node_t = Node; + +// Unique pointer to node +template +using pnode_t = std::unique_ptr>; + +// Optional node +template +using onode_t = 
std::optional>; + +// List of nodes +template +using nodes_t = std::list>; + +// List of unique pointers to nodes +template +using pnodes_t = std::list>>; + +// Variant of nodes (variadic) +template +using variant_t = std::variant...>; + +// List of variant nodes +template +using variants_t = std::list...>>; + +// ============================================================================ +// Type traits for detecting node container types +// ============================================================================ + +// Trait for basic node +template +struct is_node : std::false_type {}; + +template +struct is_node> : std::true_type {}; + +template +inline constexpr bool is_node_v = is_node::value; + +// Trait for unique pointer to node +template +struct is_pnode : std::false_type {}; + +template +struct is_pnode> : std::true_type {}; + +template +inline constexpr bool is_pnode_v = is_pnode::value; + +// Trait for optional node +template +struct is_onode : std::false_type {}; + +template +struct is_onode> : std::true_type {}; + +template +inline constexpr bool is_onode_v = is_onode::value; + +// Trait for list of nodes +template +struct is_nodes : std::false_type {}; + +template +struct is_nodes> : std::true_type {}; + +template +inline constexpr bool is_nodes_v = is_nodes::value; + +// Trait for list of unique pointers +template +struct is_pnodes : std::false_type {}; + +template +struct is_pnodes> : std::true_type {}; + +template +inline constexpr bool is_pnodes_v = is_pnodes::value; + +// Trait for variant nodes +template +struct is_variant : std::false_type {}; + +template +struct is_variant> : std::true_type {}; + +template +inline constexpr bool is_variant_v = is_variant::value; + +// Trait for list of variants +template +struct is_variants : std::false_type {}; + +template +struct is_variants> : std::true_type {}; + +template +inline constexpr bool is_variants_v = is_variants::value; + +// 
============================================================================ +// Utility traits +// ============================================================================ + +// Combined trait to check if type is any kind of node container +template +inline constexpr bool is_any_node_container_v = + is_node_v || + is_pnode_v || + is_onode_v || + is_nodes_v || + is_pnodes_v || + is_variant_v || + is_variants_v; + +// Trait to extract the Tag from a single-tag node container +template +struct extract_tag { + static constexpr bool has_single_tag = false; +}; + +template +struct extract_tag> { + static constexpr bool has_single_tag = true; + static constexpr Tag value = tag; +}; + +template +struct extract_tag> { + static constexpr bool has_single_tag = true; + static constexpr Tag value = tag; +}; + +template +struct extract_tag> { + static constexpr bool has_single_tag = true; + static constexpr Tag value = tag; +}; + +template +struct extract_tag> { + static constexpr bool has_single_tag = true; + static constexpr Tag value = tag; +}; + +template +struct extract_tag> { + static constexpr bool has_single_tag = true; + static constexpr Tag value = tag; +}; + +// Helper to check if type has a single tag +template +inline constexpr bool has_single_tag_v = extract_tag::has_single_tag; + +// Helper to get the tag value (only valid if has_single_tag_v is true) +template +inline constexpr Tag extract_tag_v = extract_tag::value; + +} // namespace autocog::compiler::stl::ast + +#endif // AUTOCOG_COMPILER_STL_AST_TRAITS_HXX diff --git a/libs/autocog/compiler/stl/diagnostic.cxx b/libs/autocog/compiler/stl/diagnostic.cxx new file mode 100644 index 0000000..a58d0a9 --- /dev/null +++ b/libs/autocog/compiler/stl/diagnostic.cxx @@ -0,0 +1,91 @@ + +#include "autocog/compiler/stl/diagnostic.hxx" + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace autocog::compiler::stl { + +Diagnostic::Diagnostic(DiagnosticLevel const level_, 
std::string message_) : + level(level_), + message(message_), + source_line(std::nullopt), + location(std::nullopt), + notes() +{} + +Diagnostic::Diagnostic(DiagnosticLevel const level_, std::string message_, SourceLocation location_) : + level(level_), + message(message_), + source_line(std::nullopt), + location(location_), + notes() +{} + +Diagnostic::Diagnostic(DiagnosticLevel const level_, std::string message_, std::string source_line_, SourceLocation location_) : + level(level_), + message(message_), + source_line(source_line_), + location(location_), + notes() +{} + +std::string Diagnostic::format(std::unordered_map const & fileids) const { + std::stringstream ss; + + // Main error message + if (location) { + auto const & loc = location.value(); + std::string filepath = "unknown"; + for (auto [fpath,fid]: fileids) { + if (fid == loc.fid) { + filepath = fpath; + break; + } + } + ss << filepath << ":" << loc.line << ":" << loc.column << ": "; + } + switch (level) { + case DiagnosticLevel::Error: ss << "error: "; break; + case DiagnosticLevel::Warning: ss << "warning: "; break; + case DiagnosticLevel::Note: ss << "note: "; break; + } + ss << message << "\n"; + + // Source line with caret + if (source_line) { + ss << " " << source_line.value() << "\n"; + if (location) { + ss << " " << std::string(location.value().column - 1, ' ') << "^\n"; + } + } + + // Additional notes + for (const auto& note : notes) { + ss << " note: " << note << "\n"; + } + + return ss.str(); +} + +CompileError::CompileError( + std::string msg, + std::optional loc +) : + message(std::move(msg)), + location(loc) +{} + +const char * CompileError::what() const noexcept { + return message.c_str(); +} + +} + diff --git a/libs/autocog/compiler/stl/diagnostic.hxx b/libs/autocog/compiler/stl/diagnostic.hxx new file mode 100644 index 0000000..44374f0 --- /dev/null +++ b/libs/autocog/compiler/stl/diagnostic.hxx @@ -0,0 +1,43 @@ +#ifndef AUTOCOG_COMPILER_STL_DIAGNOSTIC_HXX +#define 
AUTOCOG_COMPILER_STL_DIAGNOSTIC_HXX + +#include "autocog/utilities/exception.hxx" +#include "autocog/compiler/stl/location.hxx" + +#include +#include +#include +#include +#include +#include + +namespace autocog::compiler::stl { + +enum class DiagnosticLevel { Error, Warning, Note }; + +struct Diagnostic { + DiagnosticLevel const level; + std::string message; + std::optional source_line; + std::optional location; + std::vector notes; + + Diagnostic(DiagnosticLevel const level_, std::string message_); + Diagnostic(DiagnosticLevel const level_, std::string message_, SourceLocation location_); + Diagnostic(DiagnosticLevel const level_, std::string message_, std::string source_line_, SourceLocation location_); + + std::string format(std::unordered_map const & fileids) const; +}; + +struct CompileError : std::exception { + std::string message; + std::optional location; + + CompileError(std::string msg, std::optional loc = std::nullopt); + + const char * what() const noexcept override; +}; + +} + +#endif /* AUTOCOG_COMPILER_STL_DIAGNOSTIC_HXX */ diff --git a/libs/autocog/compiler/stl/driver.cxx b/libs/autocog/compiler/stl/driver.cxx new file mode 100644 index 0000000..4cafdcb --- /dev/null +++ b/libs/autocog/compiler/stl/driver.cxx @@ -0,0 +1,215 @@ + +#include "autocog/compiler/stl/driver.hxx" + +#include "autocog/compiler/stl/parser.hxx" +#include "autocog/compiler/stl/symbol-scanner.hxx" +#include "autocog/compiler/stl/evaluate.hxx" +#include "autocog/compiler/stl/instance-scanner.hxx" +//#include "autocog/compiler/stl/instantiate.hxx" + +#include "autocog/compiler/stl/ast/tostring.hxx" +#include "autocog/compiler/stl/ast/printer.hxx" + +#include +#include +#include + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +void Driver::emit_error(std::string msg, std::optional const & loc) { +#if DEBUG_Instantiator_emit_error + std::cerr << "Instantiator::emit_error" << std::endl; +#endif + if (loc) { + auto start = loc.value().start; + 
/// Split a qualified symbol scope string of the form
///   "<fileid>::<name>"            -> {fileid, nullopt, name}
///   "<fileid>::<object>::<name>"  -> {fileid, object, name}
/// Any further "::" separators are kept inside `name`.
///
/// @throws std::runtime_error on a malformed scope: missing "::", a fileid
///         that is not a (complete) integer, or empty components.
static std::tuple<int, std::optional<std::string>, std::string> parse_scope(std::string const & scope) {
  auto first_delim = scope.find("::");
  if (first_delim == std::string::npos) {
    throw std::runtime_error("Invalid scope format: " + scope);
  }

  int fileid;
  try {
    // Use the `pos` out-parameter so that trailing garbage is rejected:
    // previously "12abc::x" silently parsed its fileid as 12.
    std::string const head = scope.substr(0, first_delim);
    std::size_t consumed = 0;
    fileid = std::stoi(head, &consumed);
    if (consumed != head.size()) {
      throw std::runtime_error("Invalid fileid in scope: " + scope);
    }
  } catch (std::runtime_error const &) {
    throw; // re-throw our own message unchanged
  } catch (...) {
    // std::stoi throws invalid_argument / out_of_range; normalize them.
    throw std::runtime_error("Invalid fileid in scope: " + scope);
  }

  auto rest = scope.substr(first_delim + 2);
  auto second_delim = rest.find("::");

  if (second_delim == std::string::npos) {
    // Two-component form: "<fileid>::<name>"
    if (rest.empty()) {
      throw std::runtime_error("Empty name in scope: " + scope);
    }
    return {fileid, std::nullopt, rest};
  } else {
    // Three-component form: "<fileid>::<object>::<name>"
    auto object = rest.substr(0, second_delim);
    auto name = rest.substr(second_delim + 2);
    if (object.empty() || name.empty()) {
      throw std::runtime_error("Empty object or name in scope: " + scope);
    }
    return {fileid, object, name};
  }
}
need_evaluation.emplace_back(fid,name,defn.node.location); + } + + } + } + } + Evaluator evaluator(this->diagnostics, this->tables); + for (auto [fid,name,loc]: need_evaluation) { + auto sfid = std::to_string(fid); + auto & context = this->tables.contexts[sfid]; + try { + evaluator.retrieve_value(sfid, name, context, loc); // FIXME could avoid a lookup by calling `evaluate` directly... + } catch (CompileError const & e) { + emit_error(e.message, e.location); + } + } + if (this->report_errors()) return 103; + +#if !defined(NDEBUG) + std::cerr << "After evaluating globals (#3):" << std::endl; + this->tables.dump(std::cerr); +#endif + + // Collect instantiations + + InstanceScanner instance_scanner(*this); + this->traverse_ast(instance_scanner); + for ([[maybe_unused]] auto & [objptr, path]: instance_scanner.objects) { +#if !defined(NDEBUG) + std::cerr << "> object: " << ast::refString(*objptr) << std::endl; +#endif + // TODO + + } + for ([[maybe_unused]] auto & [fmtptr, path]: instance_scanner.formats) { +#if !defined(NDEBUG) + std::cerr << "> format: " << ast::refString(*fmtptr) << std::endl; +#endif + // TODO + } + if (this->report_errors()) return 104; + +#if !defined(NDEBUG) + std::cerr << "After collecting instantiations (#4):" << std::endl; + this->tables.dump(std::cerr); +#endif + + // Instantiate all exported prompts associated to input files + +// Instantiator instantiator(programs, diagnostics, symbol_tables); +// +// instantiator.evaluate_defines(); +// if (report_errors()) return 3; +// +// instantiator.generate_symbols(); +// if (report_errors()) return 4; +// +// SymbolTablesChecker symbol_tables_checker(diagnostics, symbol_tables); +// traverse_ast(symbol_tables_checker); +// if (report_errors()) return 5; +// +// instantiator.instantiate(); +// if (report_errors()) return 6; + + return std::nullopt; +} + +std::optional Driver::compile() { + return autocog::utilities::wrap_exception(&Driver::compile__, *this); +} + +int Driver::backend() { + // TODO + 
return 0; +} + +} diff --git a/libs/autocog/compiler/stl/driver.hxx b/libs/autocog/compiler/stl/driver.hxx new file mode 100644 index 0000000..8673c01 --- /dev/null +++ b/libs/autocog/compiler/stl/driver.hxx @@ -0,0 +1,67 @@ +#ifndef AUTOCOG_COMPILER_STL_DRIVER_HXX +#define AUTOCOG_COMPILER_STL_DRIVER_HXX + +#include "autocog/compiler/stl/diagnostic.hxx" +#include "autocog/compiler/stl/ast.hxx" +#include "autocog/compiler/stl/symbol-table.hxx" + +#include +#include +#include +#include +#include + +namespace autocog::compiler::stl { + +class Driver { + public: + std::list inputs; + std::list includes; + std::unordered_map defines; + + std::optional output = std::nullopt; + bool verbose = false; + + private: + unsigned errors = 0; + unsigned warnings = 0; + unsigned notes = 0; + std::list diagnostics; + std::unordered_map fileids; + + bool report_errors(); + void emit_error(std::string msg, std::optional const & loc); + + public: + std::optional fileid(std::string const & filename) const; + + private: + std::list programs; + + template + void traverse_ast(TraversalT & traversal) { + for (auto const & program: programs) { + try { + program.traverse(traversal); + } catch (CompileError const & e) { + emit_error(e.message, e.location); + } + } + } + + private: + SymbolTable tables; + + private: + std::optional compile__(); + + public: + std::optional compile(); + int backend(); + + friend class SymbolScanner; +}; + +} + +#endif // AUTOCOG_COMPILER_STL_DRIVER_HXX diff --git a/libs/autocog/compiler/stl/eval-utils.txx b/libs/autocog/compiler/stl/eval-utils.txx new file mode 100644 index 0000000..b0c3ed1 --- /dev/null +++ b/libs/autocog/compiler/stl/eval-utils.txx @@ -0,0 +1,151 @@ + +namespace autocog::compiler::stl::eval_utils { + +// Static helper function for arithmetic operations +template +ir::Value evaluateArithmetic(ir::Value const & lhs_val, ir::Value const & rhs_val, + std::optional const & loc = std::nullopt) { + return std::visit([&loc](auto const & l, auto const 
& r) -> ir::Value { + using L = std::decay_t; + using R = std::decay_t; + + // Special case: string concatenation for Add + if constexpr (Op == ast::OpKind::Add) { + if constexpr (std::is_same_v && std::is_same_v) { + return l + r; + } + } + + // Special case: Mod only works with integers + if constexpr (Op == ast::OpKind::Mod) { + if constexpr (std::is_same_v && std::is_same_v) { + if (r == 0) { + throw CompileError("Modulo by zero", loc); + } + return l % r; + } else { + throw CompileError("Modulo requires integer operands", loc); + } + } + + // General arithmetic for Add, Sub, Mul, Div + if constexpr (std::is_same_v && std::is_same_v) { + if constexpr (Op == ast::OpKind::Add) return l + r; + else if constexpr (Op == ast::OpKind::Sub) return l - r; + else if constexpr (Op == ast::OpKind::Mul) return l * r; + else if constexpr (Op == ast::OpKind::Div) { + if (r == 0) throw CompileError("Division by zero", loc); + // Integer division promotes to float + return static_cast(l) / static_cast(r); + } + else throw std::runtime_error("Invalid arithmetic operator: " + ast::opKindToString(Op)); + } + else if constexpr ((std::is_arithmetic_v && std::is_arithmetic_v) && + (!std::is_same_v && !std::is_same_v)) { + float lf = static_cast(l); + float rf = static_cast(r); + if constexpr (Op == ast::OpKind::Add) return lf + rf; + else if constexpr (Op == ast::OpKind::Sub) return lf - rf; + else if constexpr (Op == ast::OpKind::Mul) return lf * rf; + else if constexpr (Op == ast::OpKind::Div) { + if (rf == 0.0f) throw CompileError("Division by zero", loc); + return lf / rf; + } + else throw std::runtime_error("Invalid arithmetic operator: " + ast::opKindToString(Op)); + } + else { + if constexpr (Op == ast::OpKind::Add) { + throw CompileError("Invalid types for addition", loc); + } else if constexpr (Op == ast::OpKind::Sub) { + throw CompileError("Invalid types for subtraction", loc); + } else if constexpr (Op == ast::OpKind::Mul) { + throw CompileError("Invalid types for 
multiplication", loc); + } else if constexpr (Op == ast::OpKind::Div) { + throw CompileError("Invalid types for division", loc); + } else { + throw CompileError("Invalid types for arithmetic operation", loc); + } + } + }, lhs_val, rhs_val); +} + +// Static helper function for comparison operations +template +ir::Value evaluateComparison(ir::Value const & lhs_val, ir::Value const & rhs_val, + std::optional const & loc = std::nullopt) { + return std::visit([&loc](auto const & l, auto const & r) -> ir::Value { + using L = std::decay_t; + using R = std::decay_t; + + // Equality and inequality work for all type combinations + if constexpr (Op == ast::OpKind::Eq || Op == ast::OpKind::Neq) { + if constexpr (std::is_same_v) { + if constexpr (Op == ast::OpKind::Eq) return l == r; + else return l != r; + } + else if constexpr ((std::is_arithmetic_v && std::is_arithmetic_v) && + (!std::is_same_v && !std::is_same_v)) { + // Allow numeric comparison between int and float + double ld = static_cast(l); + double rd = static_cast(r); + if constexpr (Op == ast::OpKind::Eq) return ld == rd; + else return ld != rd; + } + else { + // Different types: always false for ==, always true for != + if constexpr (Op == ast::OpKind::Eq) return false; + else return true; + } + } + // Ordering comparisons only work for numbers and strings + else { + if constexpr ((std::is_arithmetic_v && std::is_arithmetic_v) && + (!std::is_same_v && !std::is_same_v)) { + double ld = static_cast(l); + double rd = static_cast(r); + if constexpr (Op == ast::OpKind::Lt) return ld < rd; + else if constexpr (Op == ast::OpKind::Lte) return ld <= rd; + else if constexpr (Op == ast::OpKind::Gt) return ld > rd; + else if constexpr (Op == ast::OpKind::Gte) return ld >= rd; + else throw std::runtime_error("Invalid comparison operator"); + } + else if constexpr (std::is_same_v && std::is_same_v) { + if constexpr (Op == ast::OpKind::Lt) return l < r; + else if constexpr (Op == ast::OpKind::Lte) return l <= r; + else if 
constexpr (Op == ast::OpKind::Gt) return l > r; + else if constexpr (Op == ast::OpKind::Gte) return l >= r; + else throw std::runtime_error("Invalid comparison operator"); + } + else { + throw CompileError("Invalid types for ordering comparison", loc); + } + } + }, lhs_val, rhs_val); +} + +// Static helper function for logical operations +template +ir::Value evaluateLogical(ir::Value const & lhs_val, ir::Value const & rhs_val, + std::optional const & loc = std::nullopt) { + return std::visit([&loc](auto const & l, auto const & r) -> ir::Value { + using L = std::decay_t; + using R = std::decay_t; + + if constexpr (std::is_same_v && std::is_same_v) { + if constexpr (Op == ast::OpKind::And) return l && r; + else if constexpr (Op == ast::OpKind::Or) return l || r; + else throw std::runtime_error("Invalid logical operator: " + ast::opKindToString(Op)); + } else { + if constexpr (Op == ast::OpKind::And) { + throw CompileError("Logical AND requires boolean operands", loc); + } else if constexpr (Op == ast::OpKind::Or) { + throw CompileError("Logical OR requires boolean operands", loc); + } else { + throw std::runtime_error("Invalid logical operator: " + ast::opKindToString(Op)); + } + } + }, lhs_val, rhs_val); +} + +} // namespace autocog::compiler + diff --git a/libs/autocog/compiler/stl/evaluate.cxx b/libs/autocog/compiler/stl/evaluate.cxx new file mode 100644 index 0000000..ff0a4bf --- /dev/null +++ b/libs/autocog/compiler/stl/evaluate.cxx @@ -0,0 +1,290 @@ + +#include "evaluate.hxx" + +#include "autocog/compiler/stl/symbol-table.hxx" + +#include + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +Evaluator::Evaluator(std::list & diagnostics_, SymbolTable & tables_) : + diagnostics(diagnostics_), + tables(tables_) +{} + +ir::Value Evaluator::evaluate(std::string const & scope, ast::Unary const & op, ir::VarMap & varmap) { + ir::Value operand_val = evaluate(scope, *op.data.operand, varmap); + + switch (op.data.kind) { + case ast::OpKind::Neg: + 
return std::visit([&op](auto const & v) -> ir::Value { + using V = std::decay_t; + if constexpr (std::is_same_v) { + return -v; + } else if constexpr (std::is_same_v) { + return -v; + } else { + throw CompileError("Cannot negate non-numeric value", op.location); + } + }, operand_val); + + case ast::OpKind::Not: + return std::visit([&op](auto const & v) -> ir::Value { + using V = std::decay_t; + if constexpr (std::is_same_v) { + return !v; + } else { + throw CompileError("Cannot apply 'not' to non-boolean value", op.location); + } + }, operand_val); + + default: + throw std::runtime_error("Invalid unary operator kind: " + ast::opKindToString(op.data.kind)); + } +} + +ir::Value Evaluator::evaluate(std::string const & scope, ast::Binary const & op, ir::VarMap & varmap) { + ir::Value lhs_val = evaluate(scope, *op.data.lhs, varmap); + ir::Value rhs_val = evaluate(scope, *op.data.rhs, varmap); + + switch (op.data.kind) { + // Arithmetic operators + case ast::OpKind::Add: + return eval_utils::evaluateArithmetic(lhs_val, rhs_val, op.location); + case ast::OpKind::Sub: + return eval_utils::evaluateArithmetic(lhs_val, rhs_val, op.location); + case ast::OpKind::Mul: + return eval_utils::evaluateArithmetic(lhs_val, rhs_val, op.location); + case ast::OpKind::Div: + return eval_utils::evaluateArithmetic(lhs_val, rhs_val, op.location); + case ast::OpKind::Mod: + return eval_utils::evaluateArithmetic(lhs_val, rhs_val, op.location); + + // Comparison operators + case ast::OpKind::Eq: + return eval_utils::evaluateComparison(lhs_val, rhs_val, op.location); + case ast::OpKind::Neq: + return eval_utils::evaluateComparison(lhs_val, rhs_val, op.location); + case ast::OpKind::Lt: + return eval_utils::evaluateComparison(lhs_val, rhs_val, op.location); + case ast::OpKind::Lte: + return eval_utils::evaluateComparison(lhs_val, rhs_val, op.location); + case ast::OpKind::Gt: + return eval_utils::evaluateComparison(lhs_val, rhs_val, op.location); + case ast::OpKind::Gte: + return 
eval_utils::evaluateComparison(lhs_val, rhs_val, op.location); + + // Logical operators + case ast::OpKind::And: + return eval_utils::evaluateLogical(lhs_val, rhs_val, op.location); + case ast::OpKind::Or: + return eval_utils::evaluateLogical(lhs_val, rhs_val, op.location); + + default: + throw std::runtime_error("Invalid binary operator kind"); + } +} + +ir::Value Evaluator::evaluate(std::string const & scope, ast::Conditional const & op, ir::VarMap & varmap) { + ir::Value cond_val = evaluate(scope, *op.data.cond, varmap); + + bool condition = std::visit([&op](auto const & v) -> bool { + using V = std::decay_t; + if constexpr (std::is_same_v) { + return v; + } else { + throw CompileError("Conditional expression requires boolean condition", op.location); + } + }, cond_val); + + if (condition) { + return evaluate(scope, *op.data.e_true, varmap); + } else { + return evaluate(scope, *op.data.e_false, varmap); + } +} + +ir::Value Evaluator::evaluate(std::string const & scope, ast::String const & fstring, ir::VarMap & varmap) { + if (!fstring.data.is_format) { + return fstring.data.value; + } + + const std::string & fmt = fstring.data.value; + std::string result; + result.reserve(fmt.length() * 1.5); + + size_t i = 0; + while (i < fmt.length()) { + // Handle escaped braces + if (i + 1 < fmt.length()) { + if (fmt[i] == '{' && fmt[i + 1] == '{') { + result += '{'; + i += 2; + continue; + } + if (fmt[i] == '}' && fmt[i + 1] == '}') { + result += '}'; + i += 2; + continue; + } + } + + // Handle variable substitution + if (fmt[i] == '{') { + // Find closing brace + size_t j = i + 1; + while (j < fmt.length() && fmt[j] != '}') { + j++; + } + + if (j >= fmt.length()) { + throw CompileError("Unclosed '{' in format string", fstring.location); + } + + // Extract variable name + std::string var_name = fmt.substr(i + 1, j - i - 1); + + // Validate it's a valid identifier + if (var_name.empty()) { + throw CompileError("Empty variable name in format string", fstring.location); + } + 
+ // Simple identifier validation (first char is letter or _, rest are alphanumeric or _) + if (!std::isalpha(var_name[0]) && var_name[0] != '_') { + throw CompileError("Invalid variable name in format string: " + var_name, fstring.location); + } + for (size_t k = 1; k < var_name.length(); k++) { + if (!std::isalnum(var_name[k]) && var_name[k] != '_') { + throw CompileError("Invalid variable name in format string: " + var_name, fstring.location); + } + } + + // Look up variable + ir::Value value = retrieve_value(scope, var_name, varmap, fstring.location); + + // Convert value to string and append + result += std::visit([](auto const & v) -> std::string { + using V = std::decay_t; + if constexpr (std::is_same_v) { + return std::to_string(v); + } else if constexpr (std::is_same_v) { + // Format float to avoid unnecessary trailing zeros + std::string s = std::to_string(v); + // Remove trailing zeros after decimal point + if (s.find('.') != std::string::npos) { + s.erase(s.find_last_not_of('0') + 1, std::string::npos); + // Remove decimal point if it's the last character + if (s.back() == '.') { + s.pop_back(); + } + } + return s; + } else if constexpr (std::is_same_v) { + return v ? 
"true" : "false"; + } else if constexpr (std::is_same_v) { + return v; + } else { + return ""; + } + }, value); + + i = j + 1; + } else if (fmt[i] == '}') { + throw CompileError("Unmatched '}' in format string", fstring.location); + } else { + result += fmt[i]; + i++; + } + } + + return result; +} + +ir::Value Evaluator::evaluate(std::string const & scope, ast::Expression const & expr, ir::VarMap & varmap) { + return std::visit([&](auto const & e) -> ir::Value { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return e.data.value; + } else if constexpr (std::is_same_v) { + return e.data.value; + } else if constexpr (std::is_same_v) { + return e.data.value; + } else if constexpr (std::is_same_v) { + return evaluate(scope, e, varmap); + } else if constexpr (std::is_same_v) { + return retrieve_value(scope, e.data.name, varmap, e.location); + } else if constexpr (std::is_same_v) { + return evaluate(scope, e, varmap); + } else if constexpr (std::is_same_v) { + return evaluate(scope, e, varmap); + } else if constexpr (std::is_same_v) { + return evaluate(scope, e, varmap); + } else if constexpr (std::is_same_v) { + return evaluate(scope, *(e.data.expr), varmap); + } else { + throw std::runtime_error("Unknown expression variant type"); + } + }, expr.data.expr); +} + +#define DEBUG_Evaluator_retrieve_value VERBOSE && 0 + +ir::Value Evaluator::retrieve_value( + std::string const & scope, + std::string const & varname, + ir::VarMap & varmap, + std::optional const & loc +) { +#if DEBUG_Evaluator_retrieve_value + std::cerr << "Evaluator::retrieve_value(" << varname << ")" << std::endl; +#endif + ir::Value value = nullptr; + + auto varmap_it = varmap.find(varname); + if (varmap_it != varmap.end()) { + value = varmap_it->second; + } else { + auto sym_it = this->tables.symbols.find(scope+"::"+varname); + if (sym_it == this->tables.symbols.end()) { + // TODO get from parent scope + throw std::runtime_error("NIY get value from parent scope."); + } + auto const & sym = 
sym_it->second; + if (std::holds_alternative(sym)) { + auto const & defn = std::get(sym).node; + + if (defn.data.is_argument) { + throw std::runtime_error("Argument `" + varname + "` should have been found in the variable map."); + } + + if (!defn.data.init) { + throw CompileError("Define `" + varname + "` has no initializer to evaluate.", loc); + } + + varmap[varname] = nullptr; // causes cycles to return error values + value = evaluate(scope, defn.data.init.value(), varmap); + varmap[varname] = value; + + } else { + // FIXME could look up the value in the parent scope but do we authorize shadowing of globals by object of different kind? + throw CompileError("Found a symbol for another object than a define when looking up " + varname + " for evaluation!", loc); + } + } + if (std::holds_alternative(value)) { + throw CompileError( + "Found error value when retriving variable `" + varname + "`." + + "If no previous error was reported then it is likely a circular " + + "dependency between variable initializers.", + loc + ); + } + return value; +} + + +} // namespace autocog::compiler + diff --git a/libs/autocog/compiler/stl/evaluate.hxx b/libs/autocog/compiler/stl/evaluate.hxx new file mode 100644 index 0000000..bafa397 --- /dev/null +++ b/libs/autocog/compiler/stl/evaluate.hxx @@ -0,0 +1,58 @@ +#ifndef AUTOCOG_COMPILER_STL_INSTANTIATE_HXX +#define AUTOCOG_COMPILER_STL_INSTANTIATE_HXX + +#include "autocog/compiler/stl/ast.hxx" +#include "autocog/compiler/stl/ir.hxx" +#include "autocog/compiler/stl/diagnostic.hxx" + +#include +#include +#include +#include + +namespace autocog::compiler::stl { + +class SymbolTable; +class Evaluator { + private: + std::list & diagnostics; + SymbolTable & tables; + + private: + void emit_error(std::string msg, std::optional const & loc); + + private: + ir::Value evaluate( + std::string const &, ast::Expression const &, ir::VarMap & + ); + + ir::Value evaluate( + std::string const &, ast::Unary const &, ir::VarMap & + ); + + ir::Value 
evaluate( + std::string const &, ast::Binary const &, ir::VarMap & + ); + + ir::Value evaluate( + std::string const &, ast::Conditional const &, ir::VarMap & + ); + + ir::Value evaluate( + std::string const &, ast::String const &, ir::VarMap & + ); + + public: + Evaluator(std::list & diagnostics_, SymbolTable & tables); + + ir::Value retrieve_value( + std::string const &, std::string const &, ir::VarMap &, + std::optional const & = std::nullopt + ); +}; + +} + +#include "autocog/compiler/stl/eval-utils.txx" + +#endif // AUTOCOG_COMPILER_STL_INSTANTIATE_HXX diff --git a/libs/autocog/compiler/stl/instance-scanner.cxx b/libs/autocog/compiler/stl/instance-scanner.cxx new file mode 100644 index 0000000..bb59569 --- /dev/null +++ b/libs/autocog/compiler/stl/instance-scanner.cxx @@ -0,0 +1,33 @@ + +#include "autocog/compiler/stl/instance-scanner.hxx" +#include "autocog/compiler/stl/driver.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +InstanceScanner::InstanceScanner( + Driver & driver_ +) : + driver(driver_), + path(), + objects(), + formats() +{} + +template <> +void InstanceScanner::pre(ast::ObjectRef const & node) { + this->objects.emplace(&node, this->path); + path.push_back(&node); +} + +template <> +void InstanceScanner::pre(ast::FormatRef const & node) { + this->formats.emplace(&node, this->path); + path.push_back(&node); +} + +} + diff --git a/libs/autocog/compiler/stl/instance-scanner.hxx b/libs/autocog/compiler/stl/instance-scanner.hxx new file mode 100644 index 0000000..3dafaf0 --- /dev/null +++ b/libs/autocog/compiler/stl/instance-scanner.hxx @@ -0,0 +1,52 @@ +#ifndef AUTOCOG_COMPILER_STL_INSTANCE_SCANNER_HXX +#define AUTOCOG_COMPILER_STL_INSTANCE_SCANNER_HXX + +#include "autocog/compiler/stl/ast.hxx" + +#include + +namespace autocog::compiler::stl { + +class Driver; + +class InstanceScanner { + public: + using NodePath = std::vector; + + private: + Driver & driver; + NodePath path; + std::unordered_map objects; + 
std::unordered_map formats; + + public: + InstanceScanner(Driver &); + + template + bool shortcut(ast::Node const &) const { + return false; + } + + template + void pre(ast::Node const & node) { + path.push_back(&node); + } + + template + void post(ast::Node const &) { + path.pop_back(); + } + + friend Driver; +}; + +template <> +void InstanceScanner::pre(ast::ObjectRef const &); + +template <> +void InstanceScanner::pre(ast::FormatRef const &); + + +} + +#endif // AUTOCOG_COMPILER_STL_INSTANCE_SCANNER_HXX diff --git a/libs/autocog/compiler/stl/instantiate.cxx b/libs/autocog/compiler/stl/instantiate.cxx new file mode 100644 index 0000000..bee8941 --- /dev/null +++ b/libs/autocog/compiler/stl/instantiate.cxx @@ -0,0 +1,430 @@ + +#include "instantiate.hxx" + +#include +#include +#include +#include + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +#define DEBUG_Instantiator_emit_error VERBOSE && 0 + +void Instantiator::emit_error(std::string msg, std::optional const & loc) { +#if DEBUG_Instantiator_emit_error + std::cerr << "Instantiator::emit_error" << std::endl; +#endif + if (loc) { + auto start = loc.value().start; + diagnostics.emplace_back(DiagnosticLevel::Error, msg, start); + } else { + diagnostics.emplace_back(DiagnosticLevel::Error, msg); + } +} + +Instantiator::Instantiator( + std::list<, ast::Program> const & programs_, + std::list & diagnostics_, + SymbolTables & tables_ +) : + programs(programs_), + diagnostics(diagnostics_), + tables(tables_), + exports(), + instantiations(), + record_cache() +{} + +[[maybe_unused]] static std::string valueToString(ir::Value const & value) { + return std::visit([](auto const & v) -> std::string { + using V = std::decay_t; + if constexpr (std::is_same_v) { + return "int(" + std::to_string(v) + ")"; + } else if constexpr (std::is_same_v) { + std::string s = std::to_string(v); + // Remove trailing zeros after decimal point + if (s.find('.') != std::string::npos) { + s.erase(s.find_last_not_of('0') + 
// FNV-1a 64-bit hash: deterministic across platforms and runs, unlike
// std::hash. Used to produce stable mangled names for string parameters.
// Algorithm: start from the offset basis, then for each byte XOR-then-multiply
// by the FNV prime.
uint64_t fnv1a_hash(const std::string& str) {
  constexpr uint64_t kOffsetBasis = 0xcbf29ce484222325ULL; // FNV-1a 64-bit offset basis
  constexpr uint64_t kPrime = 0x100000001b3ULL;            // FNV-1a 64-bit prime

  uint64_t digest = kOffsetBasis;
  for (std::size_t i = 0; i < str.size(); ++i) {
    digest ^= static_cast<uint64_t>(static_cast<unsigned char>(str[i]));
    digest *= kPrime;
  }
  return digest;
}
"bT" : "bF"; + } else if constexpr (std::is_same_v) { + // Use FNV-1a hash for string values + uint64_t hash_value = fnv1a_hash(v); + std::ostringstream oss; + oss << "s" << std::hex << hash_value; + mangled += oss.str(); + } else if constexpr (std::is_same_v) { + mangled += "null"; + } + }, value); + } + + return mangled; +} + +#define DEBUG_Instantiator_evaluate_defines VERBOSE && 0 + +void Instantiator::evaluate_defines() { +#if DEBUG_Instantiator_evaluate_defines + std::cerr << "Instantiator::evaluate_defines" << std::endl; +#endif + for (auto const & program: programs) { +#if DEBUG_Instantiator_evaluate_defines + std::cerr << " " << program.data.filename << std::endl; +#endif + auto & varmap = tables.globals[program.data.filename]; + for (auto const & [varname, defn]: program.data.defines) { + if (defn.data.name != varname) { + throw std::runtime_error("Inconsistency of Define statement name."); + } + + if (defn.data.argument) { + emit_error("Top level definition must not be arguments!", defn.location); + return; + } + + if (!defn.data.init) { + emit_error("Define without initializer!", defn.location); + return; + } + + try { + retrieve_value(program, defn.data.name, varmap, defn.location); + } catch (CompileError const & e) { + emit_error(e.message, e.location); + } +#if DEBUG_Instantiator_evaluate_defines + std::cerr << "> varmap[" << defn.data.name << "] = " << valueToString(varmap[defn.data.name]) << std::endl; +#endif + } + } +} + +void Instantiator::scan_import_statement(SymbolTable & symtbl, ast::Import const & import) { + auto const & filename = import.data.file; + bool has_stl_ext = filename.size() >= 4 && ( filename.rfind(".stl") == filename.size() - 4 ); + bool has_py_ext = filename.size() >= 3 && ( filename.rfind(".py") == filename.size() - 3 ); + if (has_stl_ext) { + auto prog_it = programs.begin(); + while (prog_it != programs.end()) { + if (program.data.filename == filename) break; + ++prog_it; + } + if (prog_it == programs.end()) { + throw 
std::runtime_error("STL file `" + filename + "` in import statement was not parsed."); + } + for (auto [alias,target]: import.data.targets) { + symtbl.emplace(alias, UnresolvedImport(filename, target, import)); + } + } else if (has_py_ext) { + for (auto const & [alias,target]: import.data.targets) { + symtbl.emplace(alias, PythonSymbol(filename, target, alias)); + } + } else { + emit_error("Imported file with unknown extension.", import.location); + } +} + +static void replace_symbol(SymbolTable & symtbl, std::string const & alias, AnySymbol const & source) { + symtbl.erase(alias); + std::visit([&symtbl, &alias](auto const & sym) { + using T = std::decay_t; + if constexpr (std::is_same_v || std::is_same_v) { + symtbl.emplace(alias, T(sym.scope, sym.node, sym.name)); + } else if constexpr (std::is_same_v) { + symtbl.emplace(alias, T(sym.filename, sym.callable, sym.name)); + } else { + throw std::runtime_error("Helper function `replace_symbol` should never have been called with an `UnresolvedImport`."); + } + }, source); +} + +#define DEBUG_Instantiator_generate_symbols VERBOSE && 0 + +void Instantiator::generate_symbols() { +#if DEBUG_Instantiator_generate_symbols + std::cerr << "Instantiator::generate_symbols" << std::endl; +#endif + for (auto const & program: programs) { + SymbolTable & symtbl = tables.symbols[program.data.filename]; + for (auto const & import_statement: program.data.imports) { + scan_import_statement(symtbl, import_statement); + } + for (auto const & [name, record]: program.data.records) { + symtbl.emplace(name, RecordSymbol(program, record, name)); + } + for (auto const & [name, prompt]: program.data.prompts) { + symtbl.emplace(name, PromptSymbol(program, prompt, name)); + } + } + bool resolved_symbols = true; + while (resolved_symbols) { + resolved_symbols = false; + for (auto & [fname, symtbl]: tables.symbols) { +#if DEBUG_Instantiator_generate_symbols + std::cerr << " IN " << fname << std::endl; +#endif + std::vector> unresolved_symbols; + 
for (auto & [alias, symbol]: symtbl) { +#if DEBUG_Instantiator_generate_symbols + std::cerr << " SEE " << alias << std::endl; +#endif + if (std::holds_alternative(symbol)) { + auto & unresolved = std::get(symbol); +#if DEBUG_Instantiator_generate_symbols + std::cerr << " unresolved " << unresolved.objname << " from " << unresolved.filename << std::endl; +#endif + unresolved_symbols.emplace_back(alias, unresolved); + } + } +#if DEBUG_Instantiator_generate_symbols + std::cerr << " FOUND " << unresolved_symbols.size() << " unresolved symbols" << std::endl; +#endif + for (auto & [alias, unresolved]: unresolved_symbols) { + auto & imported_symtbl = tables.symbols[unresolved.filename]; + auto sym_it = imported_symtbl.find(unresolved.objname); + if (sym_it == imported_symtbl.end()) { + emit_error("Trying to import a non-existant object `" + unresolved.objname + "`", unresolved.import.location); + symtbl.erase(alias); + } else if (!std::holds_alternative(sym_it->second)) { +#if DEBUG_Instantiator_generate_symbols + std::cerr << " resolving " << alias << " from " << fname << std::endl; +#endif + replace_symbol(symtbl, alias, sym_it->second); + resolved_symbols = true; + } + } + } + } + + for (auto & [fname, symtbl]: tables.symbols) { + for (auto & [alias, symbol]: symtbl) { + if (std::holds_alternative(symbol)) { + auto & unresolved = std::get(symbol); + emit_error("Could not resolve `" + unresolved.objname + "` (likely a circular dependency).", unresolved.import.location); + } + } + } +} + +#define DEBUG_Instantiator_scoped_context VERBOSE && 0 + +template +ir::VarMap Instantiator::scoped_context( + ScopeT const & scope, + Kwargs const & kwargs, + ir::VarMap const & parent_context, + std::optional const & loc +) { +#if DEBUG_Instantiator_scoped_context + std::cerr << "Instantiator::scoped_context" << std::endl; +#endif + + // Step 1: Validate all kwargs correspond to actual arguments + for (auto const& [name, expr] : kwargs) { +#if DEBUG_Instantiator_scoped_context + 
std::cerr << " kwargs: " << name << std::endl; +#endif + auto def_it = scope.data.defines.find(name); + if (def_it == scope.data.defines.end()) { + throw CompileError("Unknown parameter: " + name, loc); + } + auto const & defn = def_it->second; +#if DEBUG_Instantiator_scoped_context + std::cerr << " defn: " << defn.data.name << std::endl; + std::cerr << " arg: " << defn.data.argument << std::endl; + std::cerr << " init: " << (defn.data.init?"present":"") << std::endl; +#endif + if (!defn.data.argument) { + throw CompileError("Non-argument parameter: " + name, defn.location); + } + } + + // Step 2: Evaluate kwargs in parent context ONLY + ir::VarMap evaluated_kwargs; + ir::VarMap parent_copy = parent_context; // for const correctness + for (auto const& [name, expr] : kwargs) { + evaluated_kwargs[name] = evaluate(scope, expr, parent_copy); + } + + // Step 3: Build object varmap + ir::VarMap object_context = parent_context; + + // Step 4: Remove shadowed variables (all defines shadow parent scope) + for (auto const& [name, define] : scope.data.defines) { + object_context.erase(name); + } + + // Step 5: Insert evaluated kwargs into object varmap + for (auto const& [name, value] : evaluated_kwargs) { + object_context[name] = value; + } + + // Step 6: Validate all arguments have values (kwargs or init) + for (auto const& [name, define] : scope.data.defines) { + if (define.data.argument) { + if (kwargs.find(name) == kwargs.end() && !define.data.init) { + throw CompileError("Missing required argument: " + name); + } + } + } + + // Step 7: Force evaluation of all defines (both arguments and locals) + for (auto const& [name, define] : scope.data.defines) { + retrieve_value(scope, name, object_context); + } + + return object_context; +} + +template +std::string mangle( + ObjectT const & object, + ir::VarMap const & context__ +) { + ir::VarMap context; + for (auto const& [name, define] : object.data.defines) { + if (define.data.argument) { + context[name] = context__.at(name); 
+ } + } + return mangle(object.data.name, context); +} + +#define DEBUG_Instantiator_instantiate VERBOSE && 1 + +template <> +std::string Instantiator::instantiate( + ast::Record const & record, + Kwargs const & kwargs, + ir::VarMap const & context__, + std::optional const & loc +) { + ir::VarMap context = scoped_context(record, kwargs, context__, loc); + std::string mangled_name = mangle(record.data.name, context); + + if (record_cache.find(mangled_name) == record_cache.end()) { + record_cache.emplace(mangled_name, ir::Record{record.data.name, context, mangled_name}); + // TODO + } + + return mangled_name; +} + +template <> +std::string Instantiator::instantiate( + ast::Prompt const & prompt, + Kwargs const & kwargs, + ir::VarMap const & context__, + std::optional const & loc +) { + ir::VarMap context = scoped_context(prompt, kwargs, context__, loc); + std::string mangled_name = mangle(prompt.data.name, context); + + if (instantiations.find(mangled_name) == instantiations.end()) { + instantiations.emplace(mangled_name, ir::Prompt{prompt.data.name, context, mangled_name}); + // TODO explore this prompt's channels and flow to find references for more instantiations + } + + return mangled_name; +} + +void Instantiator::instantiate() { +#if DEBUG_Instantiator_instantiate + std::cerr << "Instantiator::instantiate" << std::endl; +#endif + for (auto const & program: programs) { + SymbolTable const & symtbl = tables.symbols[program.data.filename]; +#if DEBUG_Instantiator_instantiate + std::cerr << " IN " << program.data.filename << std::endl; +#endif + for (auto const & exported: program.data.exports) { +#if DEBUG_Instantiator_instantiate + std::cerr << " EXPORT " << exported.data.alias << " from " << exported.data.target << std::endl; +#endif + auto target_it = symtbl.find(exported.data.target); + if (target_it == symtbl.end()) { + emit_error("Trying to export something that does not exist.", exported.location); + continue; + } else if 
(!std::holds_alternative(target_it->second)) { + emit_error("Only prompt can be exported.", exported.location); + continue; + } + + auto const & symbol = std::get(target_it->second); + try { + std::string mangled_name = instantiate(symbol.node, exported.data.kwargs, tables.globals[symbol.scope.data.filename]); + tables.exports.emplace(exported.data.alias, mangled_name); + } catch (CompileError const & e) { + emit_error(e.message, e.location); + } + } + } +} + +} // namespace autocog::compiler + diff --git a/libs/autocog/compiler/stl/instantiate.hxx b/libs/autocog/compiler/stl/instantiate.hxx new file mode 100644 index 0000000..540980c --- /dev/null +++ b/libs/autocog/compiler/stl/instantiate.hxx @@ -0,0 +1,79 @@ +#ifndef AUTOCOG_COMPILER_STL_INSTANTIATE_HXX +#define AUTOCOG_COMPILER_STL_INSTANTIATE_HXX + +#include "autocog/compiler/stl/ast.hxx" +#include "autocog/compiler/stl/ir.hxx" +#include "autocog/compiler/stl/symbol-table.hxx" + +#include "autocog/compiler/stl/diagnostic.hxx" + +#include +#include +#include +#include + +namespace autocog::compiler::stl { + +std::string mangle( + std::string const &, + ir::VarMap const & +); + +using Kwargs = std::unordered_map; + +template +std::string mangle( + ObjectT const &, + ir::VarMap const & +); + +class Instantiator { + private: + std::list const & programs; + std::list & diagnostics; + SymbolTables & tables; + + std::unordered_map instantiations; + std::unordered_map record_cache; + + private: + void emit_error(std::string msg, std::optional const & loc); + + private: + void scan_import_statement( + SymbolTable & symtbl, + ast::Import const & import + ); + + private: + template + ir::VarMap scoped_context( + ScopeT const &, + Kwargs const &, + ir::VarMap const &, + std::optional const & = std::nullopt + ); + + template + std::string instantiate( + ObjectT const &, + Kwargs const &, + ir::VarMap const &, + std::optional const & = std::nullopt + ); + + + public: + Instantiator( + std::list const & programs_, + 
std::list & diagnostics_, + SymbolTables & tables_ + ); + void evaluate_defines(); + void generate_symbols(); + void instantiate(); +}; + +} + +#endif // AUTOCOG_COMPILER_STL_INSTANTIATE_HXX diff --git a/libs/autocog/compiler/stl/ir.hxx b/libs/autocog/compiler/stl/ir.hxx new file mode 100644 index 0000000..b0f80b8 --- /dev/null +++ b/libs/autocog/compiler/stl/ir.hxx @@ -0,0 +1,143 @@ +#ifndef AUTOCOG_COMPILER_STL_IR_HXX +#define AUTOCOG_COMPILER_STL_IR_HXX + +#include +#include +#include +#include +#include +#include + +namespace autocog::compiler::stl::ir { + +using Value = std::variant; +using VarMap = std::unordered_map; +using DocPath = std::list; + +struct Object { + std::string name; + std::optional annotation; + + Object(std::string name_) : + name(name_) + {} +}; + +// Forward declarations +struct Record; +using RecordPtr = std::unique_ptr; + +struct Format : public Object { + VarMap kwargs; + + Format(std::string name_) : + Object(std::move(name_)) + {} +}; + +struct Text : public Format { + // TODO: vocab source when we support it + + Text(std::string name_) : + Format(std::move(name_)) + {} +}; + +struct Enum : public Format { + std::set enumerators; + + Enum(std::string name_) : + Format(std::move(name_)), + enumerators() + {} +}; + +enum class ChoiceMode { Select, Repeat }; + +struct Choice : public Format { + ChoiceMode mode; + DocPath source; + + Choice(std::string name_, ChoiceMode mode_) : + Format(std::move(name_)), + mode(mode_), + source() + {} +}; + +struct Field : public Object { + std::optional lower; + std::optional upper; + std::variant type; + + Field(std::string name_) : + Object(std::move(name_)), + lower(std::nullopt), + upper(std::nullopt), + type(std::monostate{}) + {} +}; + +struct Record : public Object { + VarMap context; + std::string mangled; + std::vector fields; + + Record(std::string name_, VarMap context_, std::string mangled_) : + Object(std::move(name_)), + context(std::move(context_)), + mangled(std::move(mangled_)), + 
fields() + {} +}; + +// For Call sources in channels +struct CallInfo { + std::string entry; + + struct Kwarg { + std::string name; + DocPath source; + bool mapped; + }; + std::vector kwargs; + + std::unordered_map binds; +}; + +struct Prompt : public Record { + + // Channels (optional) + struct ChannelLink { + DocPath target; + std::variant source; + }; + std::vector channels; + + // Flow (optional) + struct FlowEdge { + std::string target_prompt; + std::optional label; + }; + std::vector flows; + + // Return (optional) + struct ReturnInfo { + std::optional label; + struct Field { + DocPath source; + std::optional alias; + }; + std::vector fields; + }; + std::optional return_info; + + Prompt(std::string name_, VarMap context_, std::string mangled_) : + Record(std::move(name_), std::move(context_), std::move(mangled_)), + channels(), flows(), return_info(std::nullopt) + {} +}; + +} + +#endif /* AUTOCOG_COMPILER_STL_IR_HXX */ diff --git a/libs/autocog/compiler/stl/lexer.l b/libs/autocog/compiler/stl/lexer.l new file mode 100644 index 0000000..8b771ae --- /dev/null +++ b/libs/autocog/compiler/stl/lexer.l @@ -0,0 +1,138 @@ +%top{ +#include "autocog/compiler/stl/token.hxx" +#include "autocog/compiler/stl/location.hxx" +#include +%} + +%class{ +private: + int fid = -1; + int line_number = 1; + int column_number = 1; + int last_column = 1; + unsigned current_offset = 0; + + void update_location() { + last_column = column_number; + column_number += size(); + current_offset += size(); + } + + void newline() { + line_number++; + last_column = column_number; + column_number = 1; + current_offset++; + } + +public: + void set_file_id(int fid_) { fid = fid_; } + + SourceLocation current_location() const { + return {fid, line_number, last_column, current_offset}; + } + + Token advance() { + TokenType tok = (TokenType)(this->lex()); + Token token{ tok, this->str(), this->current_location() }; + return token; + } +%} + +%option c++ noyywrap +%option lexer=Lexer +%option 
namespace=autocog::compiler::stl + +%% + +\n { newline(); } +[ \t\r]+ { update_location(); } + +"//"[^\n]* { update_location(); /* line comment */ } +"/*"([^*]|\*[^/])*"*/" { + for (size_t i = 0; i < size(); ++i) { + if (text()[i] == '\n') newline(); + else column_number++; + } +} + +// Keywords +"define" { update_location(); return (int)TokenType::DEFINE; } +"argument" { update_location(); return (int)TokenType::ARGUMENT; } +"record" { update_location(); return (int)TokenType::RECORD; } +"import" { update_location(); return (int)TokenType::IMPORT; } +"export" { update_location(); return (int)TokenType::EXPORT; } +"alias" { update_location(); return (int)TokenType::ALIAS; } +"prompt" { update_location(); return (int)TokenType::PROMPT; } +"channel" { update_location(); return (int)TokenType::CHANNEL; } +"flow" { update_location(); return (int)TokenType::FLOW; } +"return" { update_location(); return (int)TokenType::RETURN; } +"annotate" { update_location(); return (int)TokenType::ANNOTATE; } +"to" { update_location(); return (int)TokenType::TO; } +"from" { update_location(); return (int)TokenType::FROM; } +"get" { update_location(); return (int)TokenType::GET; } +"use" { update_location(); return (int)TokenType::USE; } +"call" { update_location(); return (int)TokenType::CALL; } +"mapped" { update_location(); return (int)TokenType::MAPPED; } +"ravel" { update_location(); return (int)TokenType::RAVEL; } +"bind" { update_location(); return (int)TokenType::BIND; } +"wrap" { update_location(); return (int)TokenType::WRAP; } +"prune" { update_location(); return (int)TokenType::PRUNE; } +"as" { update_location(); return (int)TokenType::AS; } +"is" { update_location(); return (int)TokenType::IS; } +"search" { update_location(); return (int)TokenType::SEARCH; } +"text" { update_location(); return (int)TokenType::TEXT; } +"select" { update_location(); return (int)TokenType::SELECT; } +"repeat" { update_location(); return (int)TokenType::REPEAT; } +"enum" { update_location(); 
return (int)TokenType::ENUM; } +"true" { update_location(); return (int)TokenType::BOOLEAN_LITERAL; } +"false" { update_location(); return (int)TokenType::BOOLEAN_LITERAL; } + +// Identifiers and literals +[a-zA-Z_][a-zA-Z0-9_]* { update_location(); return (int)TokenType::IDENTIFIER; } +[0-9]+\.[0-9]+([eE][+-]?[0-9]+)?|[0-9]+[eE][+-]?[0-9]+ { update_location(); return (int)TokenType::FLOAT_LITERAL; } +[0-9]+ { update_location(); return (int)TokenType::INTEGER_LITERAL; } + +[fF]\"([^"\\]|\\.)*\" { update_location(); return (int)TokenType::STRING_LITERAL; } +[fF]'([^'\\]|\\.)*' { update_location(); return (int)TokenType::STRING_LITERAL; } +\"([^"\\]|\\.)*\" { update_location(); return (int)TokenType::STRING_LITERAL; } +'([^'\\]|\\.)*' { update_location(); return (int)TokenType::STRING_LITERAL; } + +// Operators and delimiters +// IMPORTANT: Multi-character operators must come before single-character ones +"<=" { update_location(); return (int)TokenType::LTEQ; } +">=" { update_location(); return (int)TokenType::GTEQ; } +"==" { update_location(); return (int)TokenType::EQEQ; } +"!=" { update_location(); return (int)TokenType::BANGEQ; } +"&&" { update_location(); return (int)TokenType::AMPAMP; } +"||" { update_location(); return (int)TokenType::PIPEPIPE; } + +"{" { update_location(); return (int)TokenType::LBRACE; } +"}" { update_location(); return (int)TokenType::RBRACE; } +"[" { update_location(); return (int)TokenType::LSQUARE; } +"]" { update_location(); return (int)TokenType::RSQUARE; } +"(" { update_location(); return (int)TokenType::LPAREN; } +")" { update_location(); return (int)TokenType::RPAREN; } +";" { update_location(); return (int)TokenType::SEMICOLON; } +":" { update_location(); return (int)TokenType::COLON; } +"," { update_location(); return (int)TokenType::COMMA; } +"." 
{ update_location(); return (int)TokenType::DOT; } +"=" { update_location(); return (int)TokenType::EQUAL; } +"+" { update_location(); return (int)TokenType::PLUS; } +"-" { update_location(); return (int)TokenType::MINUS; } +"*" { update_location(); return (int)TokenType::STAR; } +"/" { update_location(); return (int)TokenType::SLASH; } +"%" { update_location(); return (int)TokenType::PERCENT; } +"<" { update_location(); return (int)TokenType::LT; } +">" { update_location(); return (int)TokenType::GT; } +"!" { update_location(); return (int)TokenType::BANG; } +"?" { update_location(); return (int)TokenType::QUESTION; } + +. { + update_location(); + return (int)TokenType::ERROR; +} + +<> { return (int)TokenType::END_OF_FILE; } + +%% diff --git a/libs/autocog/compiler/stl/location.hxx b/libs/autocog/compiler/stl/location.hxx new file mode 100644 index 0000000..324c9c2 --- /dev/null +++ b/libs/autocog/compiler/stl/location.hxx @@ -0,0 +1,20 @@ +#ifndef AUTOCOG_COMPILER_STL_LOCATION_HXX +#define AUTOCOG_COMPILER_STL_LOCATION_HXX + +namespace autocog::compiler::stl { + +struct SourceLocation { + int fid; + int line; + int column; + unsigned offset; +}; + +struct SourceRange { + SourceLocation start; + SourceLocation stop; +}; + +} + +#endif /* AUTOCOG_COMPILER_STL_LOCATION_HXX */ diff --git a/libs/autocog/compiler/stl/parser-state.cxx b/libs/autocog/compiler/stl/parser-state.cxx new file mode 100644 index 0000000..ea6c5b4 --- /dev/null +++ b/libs/autocog/compiler/stl/parser-state.cxx @@ -0,0 +1,63 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#include + +#include + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +ParserState::ParserState( + int fileid, + std::string const & source_ +) : + stream(source_), + lexer(stream), + previous(), + current(lexer.advance()) +{ + lexer.set_file_id(fileid); +} + +ParserState::ParserState( + std::string const & source_ +) : + ParserState(-1, source_) +{} + +void ParserState::advance() { + previous = 
current; + current = lexer.advance(); +} + +bool ParserState::check(TokenType type) { + return current.type == type; +} + +bool ParserState::match(TokenType type) { + if (check(type)) { + advance(); + return true; + } else { + return false; + } +} + +void ParserState::expect(TokenType type, std::string context) { + if (!match(type)) { + std::ostringstream oss; + oss << "Expected token `" << token_type_name(type) << "` but found `" << token_type_name(current.type) << "` " << context; + throw_error(oss.str()); + } +} + +void ParserState::throw_error(std::string msg) { + throw ParseError(msg, current.location); +} + +} + diff --git a/libs/autocog/compiler/stl/parser.cxx b/libs/autocog/compiler/stl/parser.cxx new file mode 100644 index 0000000..fa35ac4 --- /dev/null +++ b/libs/autocog/compiler/stl/parser.cxx @@ -0,0 +1,161 @@ + +#include "autocog/compiler/stl/parser.hxx" +#include "autocog/compiler/stl/diagnostic.hxx" + +#include + +#include + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +ParseError::ParseError( + std::string msg, + SourceLocation loc +) : + message(std::move(msg)), + location(loc) +{} + +const char * ParseError::what() const noexcept { + return message.c_str(); +} + +Parser::Parser( + std::list & diagnostics_, + std::unordered_map & fileids_, + std::list const & search_paths_, + std::list & programs_, + std::list const & filepaths +) : + search_paths(search_paths_), + diagnostics(diagnostics_), + fileids(fileids_), + programs(programs_), + queue() +{ + for (auto & filepath: filepaths) { + queue.push(filepath); + } +} + +static std::string file_lookup(std::string const & filepath, std::list const & search_paths) { + std::string found_path; + if (std::filesystem::exists(filepath)) { + found_path = filepath; + } else { + for (auto const & search_path : search_paths) { + std::filesystem::path full_path = std::filesystem::path(search_path) / filepath; + if (std::filesystem::exists(full_path)) { + found_path = full_path.string(); 
+ break; + } + } + } + return found_path; +} + +void queue_imports(ast::Data const & program, std::queue & queue) { + for (auto & stmt: program.statements) { + if (stmt.index() == 0) { + auto & import = std::get<0>(stmt); + std::string const & file = import.data.file; + bool has_stl_extension = file.size() >= 4 && ( file.rfind(".stl") == file.size() - 4 ); + if (has_stl_extension) queue.push(file); + } + } +} + +#define DEBUG_Parser_parse VERBOSE && 0 + +void Parser::parse() { +#if DEBUG_Parser_parse + std::cerr << "ENTER Parser::parse()" << std::endl; +#endif + while (!queue.empty()) { + std::string filepath = queue.front(); + queue.pop(); + +#if DEBUG_Parser_parse + std::cerr << " filepath = " << filepath << std::endl; +#endif + + if (fileids.find(filepath) != fileids.end()) continue; + int fid = fileids.size(); + fileids.emplace(filepath, fid); + + std::string found_path = file_lookup(filepath, search_paths); + if (found_path.empty()) { + std::ostringstream oss; + oss << "Cannot find file: `" << filepath << "`"; + diagnostics.emplace_back(DiagnosticLevel::Error, oss.str()); + } else { + std::ifstream file(found_path); + std::string source((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + file.close(); + + parse(fid, filepath, source); + queue_imports(programs.back().data, queue); + } + } +#if DEBUG_Parser_parse + std::cerr << "LEAVE Parser::parse()" << std::endl; +#endif +} + +void clean_raw_string(std::string raw_text, ast::Data & data) { + if (raw_text.empty()) { + throw std::runtime_error("Found empty string literal!"); + } + data.is_format = (raw_text[0] == 'f' || raw_text[0] == 'F'); + if (data.is_format) { + raw_text = raw_text.substr(1, raw_text.length() - 1); + } + if (raw_text.size() < 2) { + throw std::runtime_error("Found string literal with less than 2 characters but expect to find quotes!"); + } + if (raw_text[0] != '"' || raw_text[raw_text.length()-1] != '"') { + throw std::runtime_error("Found string literal without leading and 
ending quotes!"); + } + data.value = raw_text.substr(1, raw_text.length() - 2); +} + +static std::string get_line(std::string const & source, int line_pos) { + if (line_pos <= 0) throw std::runtime_error("In get_line(): line number must be greater than 0"); + std::stringstream ss(source); + int cnt = 0; + std::string line; + while (std::getline(ss, line)) { + cnt++; + if (cnt == line_pos) return line; + } + throw std::runtime_error("In get_line(): line number must be less than the number of lines in the file"); +} + +void Parser::parse(int fid, std::string const & name, std::string const & source) { + ParserState state(fid, source); + programs.emplace_back(name, fid); + try { + parse(state, programs.back().data); + } catch (ParseError const & e) { + auto line = get_line(source, e.location.line); + diagnostics.emplace_back(DiagnosticLevel::Error, e.message, line, e.location); + } +} + + +bool Parser::parse_fragment( + std::string const & tag_, + std::string const & code +) { + switch (ast::str2tag(tag_)) { +#define X(etag,stag) case ast::Tag::etag: return parse_fragment(code); +#include "autocog/compiler/stl/ast/nodes.def" + default: throw std::runtime_error("Unrecognized ast::Tag!"); + } +} + +} diff --git a/libs/autocog/compiler/stl/parser.hxx b/libs/autocog/compiler/stl/parser.hxx new file mode 100644 index 0000000..786e4ff --- /dev/null +++ b/libs/autocog/compiler/stl/parser.hxx @@ -0,0 +1,155 @@ +#ifndef AUTOCOG_COMPILER_STL_PARSER_HXX +#define AUTOCOG_COMPILER_STL_PARSER_HXX + +#include "autocog/compiler/stl/token.hxx" +#include "autocog/compiler/stl/ast.hxx" + +#include "autocog_compiler_stl_lexer.hxx" //< Generated file + +#include +#include +#include +#include +#include + +namespace autocog::compiler::stl { + +struct ParseError : std::exception { + std::string message; + SourceLocation location; + + ParseError(std::string msg, SourceLocation loc); + + const char * what() const noexcept override; +}; + +class Lexer; +struct Diagnostic; + +struct ParserState 
{ + std::istringstream stream; + Lexer lexer; + + Token previous; + Token current; + + ParserState(std::string const & source_); + ParserState(int fid, std::string const & source_); + + void advance(); + + bool check(TokenType type); + bool match(TokenType type); + void expect(TokenType type, std::string context); + + void throw_error(std::string msg); +}; + +class Parser { + private: + std::list const & search_paths; + std::list & diagnostics; + std::unordered_map & fileids; + std::list & programs; + std::queue queue; + + private: + void parse(int fid, std::string const & filename, std::string const & source); + + template + static void parse(ParserState &, ast::Data &) { + std::ostringstream oss; + oss << "Not implemented: autocog::compiler::stl::Parser::parse<" << ast::tag2str(tag) << ">(...)"; + throw std::runtime_error(oss.str()); + } + + static void parse_primary(ParserState & state, ast::Data & expr); + + template // TODO remove + static void parse_with_location(ParserState & state, ast::Node & node, std::optional start = std::nullopt) { + SourceLocation start_loc{start?start.value():state.current.location}; + parse(state, node.data); + node.location.emplace(SourceRange{start_loc, state.current.location}); + } + + template + static void parse(ParserState & state, ast::Node & node, std::optional start = std::nullopt) { + SourceLocation start_loc{start?start.value():state.current.location}; + parse(state, node.data); + node.location.emplace(SourceRange{start_loc, state.current.location}); + } + + /// For testing purpose + template + static bool parse_fragment(std::string const & code); + + public: + Parser( + std::list &, + std::unordered_map &, + std::list const &, + std::list &, + std::list const & + ); + + void parse(); + + /// For testing purpose + static bool parse_fragment(std::string const & tag, std::string const & code); +}; + +template +bool Parser::parse_fragment( + std::string const & code +) { + ParserState state(code); + ast::Data data; + 
parse(state, data); + return state.check(TokenType::END_OF_FILE); +} + +void clean_raw_string(std::string raw_text, ast::Data & data); + +template <> void Parser::parse (ParserState & state, ast::Data &); +template <> void Parser::parse (ParserState & state, ast::Data &); +template <> void Parser::parse (ParserState & state, ast::Data &); +template <> void Parser::parse (ParserState & state, ast::Data &); +template <> void Parser::parse (ParserState & state, ast::Data &); +template <> void Parser::parse (ParserState & state, ast::Data &); +template <> void Parser::parse (ParserState & state, ast::Data &); +template <> void Parser::parse (ParserState & state, ast::Data &); +template <> void Parser::parse (ParserState & state, ast::Data &); +template <> void Parser::parse (ParserState & state, ast::Data &); +template <> void Parser::parse (ParserState & state, ast::Data &); +template <> void Parser::parse (ParserState & state, ast::Data &); +template <> void Parser::parse (ParserState & state, ast::Data &); +template <> void Parser::parse (ParserState & state, ast::Data &); +template <> void Parser::parse (ParserState & state, ast::Data &); +template <> void Parser::parse (ParserState & state, ast::Data &); +template <> void Parser::parse (ParserState & state, ast::Data &); +template <> void Parser::parse (ParserState & state, ast::Data &); +template <> void Parser::parse (ParserState & state, ast::Data &); +template <> void Parser::parse (ParserState & state, ast::Data &); +template <> void Parser::parse (ParserState & state, ast::Data &); + +template <> void Parser::parse(ParserState & state, ast::Data &); +template <> void Parser::parse(ParserState & state, ast::Data &); +template <> void Parser::parse(ParserState & state, ast::Data &); +template <> void Parser::parse(ParserState & state, ast::Data &); +template <> void Parser::parse(ParserState & state, ast::Data &); + +template <> void Parser::parse(ParserState & state, ast::Data &); + +template <> void 
Parser::parse(ParserState & state, ast::Data &); + +template <> void Parser::parse(ParserState & state, ast::Data &); +template <> void Parser::parse(ParserState & state, ast::Data &); +template <> void Parser::parse(ParserState & state, ast::Data &); + +template <> void Parser::parse(ParserState & state, ast::Data &); +template <> void Parser::parse(ParserState & state, ast::Data &); +template <> void Parser::parse(ParserState & state, ast::Data & data); + +} + +#endif // AUTOCOG_COMPILER_STL_PARSER_HXX diff --git a/libs/autocog/compiler/stl/parser/alias.cxx b/libs/autocog/compiler/stl/parser/alias.cxx new file mode 100644 index 0000000..ef028d2 --- /dev/null +++ b/libs/autocog/compiler/stl/parser/alias.cxx @@ -0,0 +1,28 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +#define DEBUG_Parser_Alias VERBOSE && 0 + +template <> +void Parser::parse(ParserState & state, ast::Data & data) { +#if DEBUG_Parser_Alias + std::cerr << "Parser::parse" << std::endl; +#endif + parse_with_location(state, data.target); + if (state.match(TokenType::AS)) { +#if DEBUG_Parser_Alias + std::cerr << " matched AS" << std::endl; +#endif + data.alias.emplace(); + parse_with_location(state, data.alias.value()); + } +} + +} + diff --git a/libs/autocog/compiler/stl/parser/annotate.cxx b/libs/autocog/compiler/stl/parser/annotate.cxx new file mode 100644 index 0000000..32cd19e --- /dev/null +++ b/libs/autocog/compiler/stl/parser/annotate.cxx @@ -0,0 +1,40 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +template <> +void Parser::parse(ParserState & state, ast::Data & annotation) { + annotation.path.emplace(); + parse(state, annotation.path.value().data); + state.expect(TokenType::AS, " in annotation."); + parse(state, annotation.description.data); + state.expect(TokenType::SEMICOLON, " to end annotation."); +} + +template <> +void 
Parser::parse(ParserState & state, ast::Data & annotate) { + state.expect(TokenType::ANNOTATE, "when parsing Annotate statement."); + if (state.match(TokenType::LBRACE)) { + annotate.single_statement = false; + while (!state.match(TokenType::RBRACE)) { + annotate.annotations.emplace_back(); + auto & annotation = annotate.annotations.back().data; + parse(state, annotation); + } + } else { + annotate.single_statement = true; + annotate.annotations.emplace_back(); + auto & annotation = annotate.annotations.back().data; + + parse(state, annotation.description.data); + state.expect(TokenType::SEMICOLON, " when ending single statement annotation."); + } +} + +} + diff --git a/libs/autocog/compiler/stl/parser/assign.cxx b/libs/autocog/compiler/stl/parser/assign.cxx new file mode 100644 index 0000000..d23af32 --- /dev/null +++ b/libs/autocog/compiler/stl/parser/assign.cxx @@ -0,0 +1,18 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +template <> +void Parser::parse(ParserState & state, ast::Data & data) { + parse(state, data.argument); + state.expect(TokenType::EQUAL, "to assign value to argument."); + parse(state, data.value); +} + +} + diff --git a/libs/autocog/compiler/stl/parser/call.cxx b/libs/autocog/compiler/stl/parser/call.cxx new file mode 100644 index 0000000..7ce25eb --- /dev/null +++ b/libs/autocog/compiler/stl/parser/call.cxx @@ -0,0 +1,23 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +template <> +void Parser::parse(ParserState & state, ast::Data & call) { + parse(state, call.entry.data); + state.expect(TokenType::LBRACE, " to start call arguments."); + + while (!state.match(TokenType::RBRACE)) { + call.arguments.emplace_back(); + auto & argument = call.arguments.back().data; + parse(state, argument); + } +} + +} + diff --git a/libs/autocog/compiler/stl/parser/channel.cxx 
b/libs/autocog/compiler/stl/parser/channel.cxx new file mode 100644 index 0000000..a923983 --- /dev/null +++ b/libs/autocog/compiler/stl/parser/channel.cxx @@ -0,0 +1,22 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +template <> +void Parser::parse(ParserState & state, ast::Data & data) { + state.expect(TokenType::CHANNEL, "when parsing Channel statement."); + state.expect(TokenType::LBRACE, " when starting to parse channel body."); + while (!state.match(TokenType::RBRACE)) { + data.links.emplace_back(); + auto & link = data.links.back().data; + parse(state, link); + } +} + +} + diff --git a/libs/autocog/compiler/stl/parser/clauses.cxx b/libs/autocog/compiler/stl/parser/clauses.cxx new file mode 100644 index 0000000..8661fae --- /dev/null +++ b/libs/autocog/compiler/stl/parser/clauses.cxx @@ -0,0 +1,65 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +template <> +void Parser::parse(ParserState & state, ast::Data & clause) { + state.expect(TokenType::BIND, "when parsing Bind clause."); + state.expect(TokenType::LPAREN, "for bind clause mandatory source argument."); + parse(state, clause.source.data); + if (state.match(TokenType::COMMA)) { + clause.target.emplace(); + parse(state, clause.target.value().data); + } + state.expect(TokenType::RPAREN, "to end bind clause arguments."); +} + +template <> +void Parser::parse(ParserState & state, ast::Data & clause) { + state.expect(TokenType::PRUNE, "when parsing Prune clause."); + state.expect(TokenType::LPAREN, "for prune clause mandatory source argument."); + parse(state, clause.target.data); + state.expect(TokenType::RPAREN, "to end prune clause arguments."); +} + +template <> +void Parser::parse(ParserState & state, ast::Data & clause) { + state.expect(TokenType::RAVEL, "when parsing Ravel clause."); + if (state.match(TokenType::LPAREN)) { + clause.depth.emplace(); + 
parse(state, clause.depth.value().data); + if (state.match(TokenType::COMMA)) { + clause.target.emplace(); + parse(state, clause.target.value().data); + } + state.expect(TokenType::RPAREN, "to end ravel clause arguments."); + } +} + +template <> +void Parser::parse(ParserState & state, ast::Data & clause) { + state.expect(TokenType::MAPPED, "when parsing Mapped clause."); + if (state.match(TokenType::LPAREN)) { + clause.target.emplace(); + parse(state, clause.target.value().data); + state.expect(TokenType::RPAREN, "to end mapped clause arguments."); + } +} + +template <> +void Parser::parse(ParserState & state, ast::Data & clause) { + state.expect(TokenType::WRAP, "when parsing Wrap clause."); + if (state.match(TokenType::LPAREN)) { + clause.target.emplace(); + parse(state, clause.target.value().data); + state.expect(TokenType::RPAREN, "to end wrap clause arguments."); + } +} + +} + diff --git a/libs/autocog/compiler/stl/parser/define.cxx b/libs/autocog/compiler/stl/parser/define.cxx new file mode 100644 index 0000000..a7f38b1 --- /dev/null +++ b/libs/autocog/compiler/stl/parser/define.cxx @@ -0,0 +1,33 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +#define DEBUG_Parser_Define VERBOSE && 0 + +template <> +void Parser::parse(ParserState & state, ast::Data & define) { +#if DEBUG_Parser_Define + std::cerr << "Parser::parse" << std::endl; +#endif + if (state.match(TokenType::DEFINE)) { + define.is_argument = false; + } else if (state.match(TokenType::ARGUMENT)) { + define.is_argument = true; + } else { + state.throw_error("Expect either `define` or `argument`."); + } + parse(state, define.name); + if (state.match(TokenType::EQUAL)) { + define.init.emplace(); + parse(state, define.init.value().data); + } + state.expect(TokenType::SEMICOLON, " to end define/argument statement."); +} + +} + diff --git a/libs/autocog/compiler/stl/parser/expression.cxx 
b/libs/autocog/compiler/stl/parser/expression.cxx new file mode 100644 index 0000000..f272925 --- /dev/null +++ b/libs/autocog/compiler/stl/parser/expression.cxx @@ -0,0 +1,202 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +static ast::OpKind token_to_operator_kind(TokenType type) { + switch (type) { + case TokenType::BANG: return ast::OpKind::Not; + case TokenType::PLUS: return ast::OpKind::Add; + case TokenType::MINUS: return ast::OpKind::Sub; + case TokenType::STAR: return ast::OpKind::Mul; + case TokenType::SLASH: return ast::OpKind::Div; + case TokenType::PERCENT: return ast::OpKind::Mod; + case TokenType::AMPAMP: return ast::OpKind::And; + case TokenType::PIPEPIPE: return ast::OpKind::Or; + case TokenType::LT: return ast::OpKind::Lt; + case TokenType::GT: return ast::OpKind::Gt; + case TokenType::LTEQ: return ast::OpKind::Lte; + case TokenType::GTEQ: return ast::OpKind::Gte; + case TokenType::EQEQ: return ast::OpKind::Eq; + case TokenType::BANGEQ: return ast::OpKind::Neq; + default: return ast::OpKind::NOP; + } +} + +[[maybe_unused]] static int get_precedence(TokenType type) { + switch (type) { + case TokenType::STAR: + case TokenType::SLASH: + case TokenType::PERCENT: + return 10; // Multiplicative + + case TokenType::PLUS: + case TokenType::MINUS: + return 9; // Additive + + case TokenType::LT: + case TokenType::GT: + case TokenType::LTEQ: + case TokenType::GTEQ: + return 8; // Relational + + case TokenType::EQEQ: + case TokenType::BANGEQ: + return 7; // Equality + + case TokenType::AMPAMP: + return 4; // Logical AND + + case TokenType::PIPEPIPE: + return 3; // Logical OR + + case TokenType::QUESTION: + return 2; // Ternary conditional + + default: + return -1; // Not a binary operator + } +} + +static bool is_unary(TokenType tok) { + return tok == TokenType::BANG || tok == TokenType::MINUS; +} + +static bool is_binary(TokenType tok) { + auto kind = token_to_operator_kind(tok); + 
return kind != ast::OpKind::Not && kind != ast::OpKind::NOP; +} + +static bool is_primary(TokenType tok) { + return tok == TokenType::STRING_LITERAL || + tok == TokenType::INTEGER_LITERAL || + tok == TokenType::FLOAT_LITERAL || + tok == TokenType::BOOLEAN_LITERAL || + tok == TokenType::IDENTIFIER; +} + +[[maybe_unused]] static bool is_conditional(TokenType tok) { + return tok == TokenType::QUESTION; +} + +template <> +void Parser::parse(ParserState & state, ast::Data & identifier) { + state.expect(TokenType::IDENTIFIER, " when parsing identifier."); + identifier.name = state.previous.text; +} + +#define DEBUG_parse_primary VERBOSE && 0 + +void Parser::parse_primary(ParserState & state, ast::Data & expr) { +#if DEBUG_parse_primary + std::cerr << "parse_primary" << std::endl; +#endif + switch (state.current.type) { + case TokenType::IDENTIFIER: { + expr.expr.emplace<1>(); + parse(state, std::get<1>(expr.expr)); + break; + } + + case TokenType::INTEGER_LITERAL: { + state.advance(); + expr.expr.emplace<2>(); + auto & data = std::get<2>(expr.expr).data; + data.value = std::stoi(state.previous.text); + break; + } + + case TokenType::FLOAT_LITERAL: { + state.advance(); + expr.expr.emplace<3>(); + auto & data = std::get<3>(expr.expr).data; + data.value = std::stof(state.previous.text); + break; + } + + case TokenType::BOOLEAN_LITERAL: { + state.advance(); + expr.expr.emplace<4>(); + auto & data = std::get<4>(expr.expr).data; + data.value = (state.previous.text == "true"); + break; + } + + case TokenType::STRING_LITERAL: { + state.advance(); + expr.expr.emplace<5>(); + auto & data = std::get<5>(expr.expr).data; + clean_raw_string(state.previous.text, data); + break; + } + + default: + state.throw_error("Expected literal or identifier in expression."); + break; + } +} + +#define DEBUG_Parse_Expression VERBOSE && 0 + +template <> +void Parser::parse(ParserState & state, ast::Data & expr) { +#if DEBUG_Parse_Expression + std::cerr << "Parser::parse" << std::endl; +#endif + if 
(is_primary(state.current.type)) { + parse_primary(state, expr); + + } else if (is_unary(state.current.type)) { + state.advance(); + expr.expr.emplace<6>(); + auto & data = std::get<6>(expr.expr).data; + data.kind = token_to_operator_kind(state.previous.type); + if (data.kind == ast::OpKind::Sub) data.kind = ast::OpKind::Neg; + data.operand = std::make_unique(); + if (!is_primary(state.current.type) && state.current.type != TokenType::LPAREN ) { + state.throw_error("Unary operator expects primary or parenthesized operand!"); + } + parse(state, *(data.operand)); + } else if (state.match(TokenType::LPAREN)) { + auto operand = std::make_unique(); + parse(state, *operand); + if (is_binary(state.current.type)) { + state.advance(); + expr.expr.emplace<7>(); + auto & data = std::get<7>(expr.expr).data; + data.kind = token_to_operator_kind(state.previous.type); + data.lhs = std::move(operand); + data.rhs = std::make_unique(); + parse(state, *(data.rhs)); + + } else if (state.match(TokenType::QUESTION)) { + expr.expr.emplace<8>(); + auto & data = std::get<8>(expr.expr).data; + data.cond = std::move(operand); + data.e_true = std::make_unique(); + parse(state, *(data.e_true)); + state.expect(TokenType::COLON, " within conditional expression."); + data.e_false = std::make_unique(); + parse(state, *(data.e_false)); + + } else { + expr.expr.emplace<9>(); + auto & data = std::get<9>(expr.expr).data; + data.expr = std::move(operand); + + } + state.expect(TokenType::RPAREN, " to end parenthesized expression."); + + } else { + std::ostringstream oss; + oss << "Expression must be primary, unary, or parenthesized (for binary and conditional)! 
`" << token_type_name(state.current.type) << "`"; + state.throw_error(oss.str()); + } +} + +} + diff --git a/libs/autocog/compiler/stl/parser/field.cxx b/libs/autocog/compiler/stl/parser/field.cxx new file mode 100644 index 0000000..d73b29c --- /dev/null +++ b/libs/autocog/compiler/stl/parser/field.cxx @@ -0,0 +1,34 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +template <> +void Parser::parse(ParserState & state, ast::Data & field) { + parse(state, field.name); + if (state.match(TokenType::LSQUARE)) { + field.lower.emplace(); + parse(state, field.lower.value()); + + if (state.match(TokenType::COLON)) { + field.upper.emplace(); + parse(state, field.upper.value()); + } + state.expect(TokenType::RSQUARE, " to close array dimension."); + } + state.expect(TokenType::IS, " between field name and type."); + if (state.check(TokenType::LBRACE)) { + field.type.emplace<2>(); + parse(state, std::get<2>(field.type)); + } else { + field.type.emplace<1>(); + parse(state, std::get<1>(field.type)); + } +} + +} + diff --git a/libs/autocog/compiler/stl/parser/fieldref.cxx b/libs/autocog/compiler/stl/parser/fieldref.cxx new file mode 100644 index 0000000..d94fb7d --- /dev/null +++ b/libs/autocog/compiler/stl/parser/fieldref.cxx @@ -0,0 +1,20 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +template <> void Parser::parse(ParserState & state, ast::Data & fldref) { + if (!state.match(TokenType::DOT)) { + fldref.prompt.emplace(); + parse(state, fldref.prompt.value().data); + state.expect(TokenType::DOT, " after prompt in field reference."); + } + parse(state, fldref.field.data); +} + +} + diff --git a/libs/autocog/compiler/stl/parser/flow.cxx b/libs/autocog/compiler/stl/parser/flow.cxx new file mode 100644 index 0000000..b8decea --- /dev/null +++ b/libs/autocog/compiler/stl/parser/flow.cxx @@ -0,0 +1,44 @@ + +#include 
"autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +template <> +void Parser::parse(ParserState & state, ast::Data & edge) { + parse(state, edge.prompt.data); + if (state.match(TokenType::LSQUARE)) { + edge.limit.emplace(); + parse(state, edge.limit.value().data); + state.expect(TokenType::RSQUARE, "to end trip limit expression."); + } + if (state.match(TokenType::AS)) { + edge.label.emplace(); + parse(state, edge.label.value().data); + } + state.expect(TokenType::SEMICOLON, " to end flow-edge statement."); +} + +template <> +void Parser::parse(ParserState & state, ast::Data & data) { + state.expect(TokenType::FLOW, "when parsing Flow statement."); + if (state.match(TokenType::LBRACE)) { + data.short_form = false; + while (!state.match(TokenType::RBRACE)) { + data.edges.emplace_back(); + auto & edge = data.edges.back().data; + parse(state, edge); + } + } else { + data.short_form = true; + data.edges.emplace_back(); + auto & edge = data.edges.back().data; + parse(state, edge); + } +} + +} + diff --git a/libs/autocog/compiler/stl/parser/format.cxx b/libs/autocog/compiler/stl/parser/format.cxx new file mode 100644 index 0000000..196581e --- /dev/null +++ b/libs/autocog/compiler/stl/parser/format.cxx @@ -0,0 +1,87 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +template <> +void Parser::parse(ParserState & state, ast::Data & type) { + state.expect(TokenType::LPAREN, ""); + state.expect(TokenType::STRING_LITERAL, " when parsing enumerators."); + type.enumerators.emplace_back(state.previous.text); + while (!state.match(TokenType::RPAREN)) { + state.expect(TokenType::COMMA, " when parsing enumerators."); + state.expect(TokenType::STRING_LITERAL, " when parsing enumerators."); + type.enumerators.emplace_back(state.previous.text); + } +} + +template <> +void Parser::parse(ParserState & state, [[maybe_unused]] ast::Data & type) { + if 
(state.match(TokenType::LPAREN)) { + // TODO how should we represent vocab source??? + state.expect(TokenType::RPAREN, "."); + } +} + +template <> +void Parser::parse(ParserState & state, ast::Data & type) { + if (state.current.type == TokenType::REPEAT) { + type.mode = ast::ChoiceKind::Repeat; + } else { + type.mode = ast::ChoiceKind::Select; + } + + state.expect(TokenType::LPAREN, " when parsing a choice format."); + parse(state, type.source); + state.expect(TokenType::RPAREN, " at the end of a choice format."); +} + +template <> +void Parser::parse(ParserState & state, ast::Data & data) { + switch (state.current.type) { + case TokenType::IDENTIFIER: { + data.type.emplace<1>(); + parse(state, std::get<1>(data.type)); + break; + } + case TokenType::TEXT: { + state.advance(); + data.type.emplace<2>(); + parse(state, std::get<2>(data.type)); + break; + } + case TokenType::ENUM: { + state.advance(); + data.type.emplace<3>(); + parse(state, std::get<3>(data.type)); + break; + } + case TokenType::REPEAT: + case TokenType::SELECT: { + state.advance(); + data.type.emplace<4>(); + parse(state, std::get<4>(data.type)); + break; + } + default: { + std::ostringstream oss; + oss << "Unexpected token `" << token_type_name(state.current.type) << "` while parsing top level statements of a Record."; + state.throw_error(oss.str()); + } + } + if (state.match(TokenType::LT)) { + do { + data.kwargs.emplace_back(); + parse(state, data.kwargs.back()); + } while (state.match(TokenType::COMMA)); + state.expect(TokenType::GT, " to close format arguments."); + } + state.expect(TokenType::SEMICOLON, " to finish format statement"); +} + +} + diff --git a/libs/autocog/compiler/stl/parser/import.cxx b/libs/autocog/compiler/stl/parser/import.cxx new file mode 100644 index 0000000..2f27fa2 --- /dev/null +++ b/libs/autocog/compiler/stl/parser/import.cxx @@ -0,0 +1,50 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +#define 
DEBUG_Parser_Import VERBOSE && 1 + +template <> +void Parser::parse(ParserState & state, ast::Data & import) { +#if DEBUG_Parser_Import + std::cerr << "Parser::parse" << std::endl; +#endif + state.expect(TokenType::FROM, "when parsing Import statement."); + state.expect(TokenType::STRING_LITERAL, "when parsing import file path."); + + std::string raw_text = state.previous.text; + + // Reject f-strings for import paths + if (!raw_text.empty() && (raw_text[0] == 'f' || raw_text[0] == 'F')) { + state.throw_error("Format strings (f-strings) are not allowed for import paths."); + } + + // Strip quotes from regular string + std::string file_path = raw_text; + if (file_path.length() >= 2 && + ((file_path.front() == '"' && file_path.back() == '"') || + (file_path.front() == '\'' && file_path.back() == '\''))) { + file_path = file_path.substr(1, file_path.length() - 2); + } + import.file = file_path; + + state.expect(TokenType::IMPORT, " after file path in import statement."); + + do { + import.targets.emplace_back(); +#if DEBUG_Parser_Import + std::cerr << " alias #" << import.targets.size() << std::endl; +#endif + parse_with_location(state, import.targets.back()); + } while (state.match(TokenType::COMMA)); + + state.expect(TokenType::SEMICOLON, " to end import statement."); +} + +} + diff --git a/libs/autocog/compiler/stl/parser/kwarg.cxx b/libs/autocog/compiler/stl/parser/kwarg.cxx new file mode 100644 index 0000000..5e4a713 --- /dev/null +++ b/libs/autocog/compiler/stl/parser/kwarg.cxx @@ -0,0 +1,59 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +template <> +void Parser::parse(ParserState & state, ast::Data & kwarg) { + state.expect(TokenType::IDENTIFIER, " for argument name."); + kwarg.name.data.name = state.previous.text; + if (state.match(TokenType::USE)) { + kwarg.source.emplace<1>(); + parse(state, std::get<1>(kwarg.source).data); + } else if (state.match(TokenType::GET)) { + 
kwarg.source.emplace<2>(); + parse(state, std::get<2>(kwarg.source).data); + } else if (state.match(TokenType::IS)) { + kwarg.source.emplace<3>(); + parse(state, std::get<3>(kwarg.source).data); + } else { + state.throw_error("Expected 'use', 'get', or 'is' to define call argument source."); + } + + while (!state.match(TokenType::SEMICOLON)) { + switch (state.current.type) { + case TokenType::BIND: + kwarg.clauses.emplace_back(std::in_place_index<0>); + parse(state, std::get<0>(kwarg.clauses.back()).data); + break; + case TokenType::RAVEL: + kwarg.clauses.emplace_back(std::in_place_index<1>); + parse(state, std::get<1>(kwarg.clauses.back()).data); + break; + case TokenType::WRAP: + kwarg.clauses.emplace_back(std::in_place_index<2>); + parse(state, std::get<2>(kwarg.clauses.back()).data); + break; + case TokenType::PRUNE: + kwarg.clauses.emplace_back(std::in_place_index<3>); + parse(state, std::get<3>(kwarg.clauses.back()).data); + break; + case TokenType::MAPPED: + kwarg.clauses.emplace_back(std::in_place_index<4>); + parse(state, std::get<4>(kwarg.clauses.back()).data); + break; + default: { + std::ostringstream oss; + oss << "Call channel argument's clauses can only be `bind`, `ravel`, `wrap`, `prune`, or `mapped`. 
Found " << token_type_name(state.current.type) << "`."; + state.throw_error(oss.str()); + } + } + } +} + +} + diff --git a/libs/autocog/compiler/stl/parser/link.cxx b/libs/autocog/compiler/stl/parser/link.cxx new file mode 100644 index 0000000..e0af9ea --- /dev/null +++ b/libs/autocog/compiler/stl/parser/link.cxx @@ -0,0 +1,55 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +template <> +void Parser::parse(ParserState & state, ast::Data & link) { + parse(state, link.target.data); + + if (state.match(TokenType::USE)) { + link.source.emplace<1>(); + parse(state, std::get<1>(link.source).data); + } else if (state.match(TokenType::GET)) { + link.source.emplace<2>(); + parse(state, std::get<2>(link.source).data); + } else if (state.match(TokenType::IS)) { + link.source.emplace<3>(); + parse(state, std::get<3>(link.source).data); + } else if (state.match(TokenType::CALL)) { + link.source.emplace<4>(); + parse(state, std::get<4>(link.source).data); + } else { + state.throw_error("Expected 'use', 'get', 'is' or 'call' to define channel's source."); + } + + while (!state.match(TokenType::SEMICOLON)) { + switch (state.current.type) { + case TokenType::BIND: + link.clauses.emplace_back(std::in_place_index<0>); + parse(state, std::get<0>(link.clauses.back()).data); + break; + case TokenType::RAVEL: + link.clauses.emplace_back(std::in_place_index<1>); + parse(state, std::get<1>(link.clauses.back()).data); + break; + case TokenType::PRUNE: + link.clauses.emplace_back(std::in_place_index<2>); + parse(state, std::get<2>(link.clauses.back()).data); + break; + case TokenType::WRAP: + link.clauses.emplace_back(std::in_place_index<3>); + parse(state, std::get<3>(link.clauses.back()).data); + break; + default: + state.throw_error("Channel clauses can only be `bind`, `ravel`, `prune`, or `wrap`"); + } + } +} + +} + diff --git a/libs/autocog/compiler/stl/parser/objectref.cxx 
b/libs/autocog/compiler/stl/parser/objectref.cxx new file mode 100644 index 0000000..29cfeda --- /dev/null +++ b/libs/autocog/compiler/stl/parser/objectref.cxx @@ -0,0 +1,24 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +template <> void Parser::parse(ParserState & state, ast::Data & pref) { + parse(state, pref.name); + if (state.match(TokenType::LT)) { + if (!state.check(TokenType::GT)) { + do { + pref.config.emplace_back(); + parse(state, pref.config.back()); + } while (state.match(TokenType::COMMA)); + } + state.expect(TokenType::GT, "to close compile arguments."); + } +} + +} + diff --git a/libs/autocog/compiler/stl/parser/path.cxx b/libs/autocog/compiler/stl/parser/path.cxx new file mode 100644 index 0000000..9d19e8b --- /dev/null +++ b/libs/autocog/compiler/stl/parser/path.cxx @@ -0,0 +1,39 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +template <> +void Parser::parse(ParserState & state, ast::Data & path) { + do { + state.expect(TokenType::IDENTIFIER, " when parsing path field."); + + path.steps.emplace_back(); + auto & step = path.steps.back().data; + step.field.data.name = state.previous.text; + step.is_range = false; + if (state.match(TokenType::LSQUARE)) { + if (!state.check(TokenType::COLON)) { + step.lower.emplace(); + parse(state, step.lower.value().data); + } + + if (state.match(TokenType::COLON)) { + step.is_range = true; + if (!state.check(TokenType::RSQUARE)) { + step.upper.emplace(); + parse(state, step.upper.value().data); + } + } + + state.expect(TokenType::RSQUARE, " to close array access."); + } + } while (state.match(TokenType::DOT)); +} + +} + diff --git a/libs/autocog/compiler/stl/parser/program.cxx b/libs/autocog/compiler/stl/parser/program.cxx new file mode 100644 index 0000000..7a4c593 --- /dev/null +++ b/libs/autocog/compiler/stl/parser/program.cxx @@ -0,0 +1,72 @@ + +#include 
"autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +template <> +void Parser::parse(ParserState & state, ast::Data & data) { + while (!state.match(TokenType::END_OF_FILE)) { + switch (state.current.type) { + case TokenType::FROM: { + data.statements.emplace_back(std::in_place_index<0>); + auto & node = std::get<0>(data.statements.back()); + parse(state, node); + break; + } + case TokenType::ALIAS: + case TokenType::EXPORT: { + auto start = state.current.location; + data.statements.emplace_back(std::in_place_index<1>); + auto & node = std::get<1>(data.statements.back()); + node.data.is_export = (state.current.type == TokenType::EXPORT); + state.advance(); + parse(state, node, start); + state.expect(TokenType::SEMICOLON, "to end alias/export statements"); + break; + } + case TokenType::DEFINE: + case TokenType::ARGUMENT: { + data.statements.emplace_back(std::in_place_index<2>); + auto & node = std::get<2>(data.statements.back()); + parse(state, node); + break; + } + case TokenType::ANNOTATE: { + data.statements.emplace_back(std::in_place_index<3>); + auto & node = std::get<3>(data.statements.back()); + parse(state, node); + break; + } + case TokenType::SEARCH: { + data.statements.emplace_back(std::in_place_index<4>); + auto & node = std::get<4>(data.statements.back()); + parse(state, node); + break; + } + case TokenType::RECORD: { + data.statements.emplace_back(std::in_place_index<5>); + auto & node = std::get<5>(data.statements.back()); + parse(state, node); + break; + } + case TokenType::PROMPT: { + data.statements.emplace_back(std::in_place_index<6>); + auto & node = std::get<6>(data.statements.back()); + parse(state, node); + break; + } + default: { + std::ostringstream oss; + oss << "Unexpected token `" << token_type_name(state.current.type) << "` while parsing statement of Program."; + state.throw_error(oss.str()); + } + } + } +} + +} + diff --git a/libs/autocog/compiler/stl/parser/prompt.cxx 
b/libs/autocog/compiler/stl/parser/prompt.cxx new file mode 100644 index 0000000..ef7b71e --- /dev/null +++ b/libs/autocog/compiler/stl/parser/prompt.cxx @@ -0,0 +1,79 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +template <> +void Parser::parse(ParserState & state, ast::Data & data) { + state.expect(TokenType::PROMPT, "when starting a Prompt."); + parse(state, data.name); + state.expect(TokenType::LBRACE, "when defining a Prompt."); + while (!state.match(TokenType::RBRACE)) { + switch (state.current.type) { + case TokenType::DEFINE: + case TokenType::ARGUMENT: { + data.constructs.emplace_back(std::in_place_index<0>); + auto & node = std::get<0>(data.constructs.back()); + parse(state, node); + break; + } + case TokenType::ANNOTATE: { + data.constructs.emplace_back(std::in_place_index<1>); + auto & node = std::get<1>(data.constructs.back()); + parse(state, node); + break; + } + case TokenType::SEARCH: { + data.constructs.emplace_back(std::in_place_index<2>); + auto & node = std::get<2>(data.constructs.back()); + parse(state, node); + break; + } + case TokenType::ALIAS: { + auto start = state.current.location; + data.constructs.emplace_back(std::in_place_index<3>); + auto & node = std::get<3>(data.constructs.back()); + state.advance(); + parse(state, node, start); + break; + } + case TokenType::CHANNEL: { + data.constructs.emplace_back(std::in_place_index<4>); + auto & node = std::get<4>(data.constructs.back()); + parse(state, node); + break; + } + case TokenType::FLOW: { + data.constructs.emplace_back(std::in_place_index<5>); + auto & node = std::get<5>(data.constructs.back()); + parse(state, node); + break; + } + case TokenType::RETURN: { + data.constructs.emplace_back(std::in_place_index<6>); + auto & node = std::get<6>(data.constructs.back()); + parse(state, node); + break; + } + case TokenType::IS: { + auto start = state.current.location; + data.fields.emplace(); + state.advance(); + 
parse(state, data.fields.value(), start); + break; + } + default: { + std::ostringstream oss; + oss << "Unexpected token `" << token_type_name(state.current.type) << "` while parsing top level statements of a Prompt."; + state.throw_error(oss.str()); + } + } + } +} + +} + diff --git a/libs/autocog/compiler/stl/parser/record.cxx b/libs/autocog/compiler/stl/parser/record.cxx new file mode 100644 index 0000000..f86c1eb --- /dev/null +++ b/libs/autocog/compiler/stl/parser/record.cxx @@ -0,0 +1,66 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +template <> +void Parser::parse(ParserState & state, ast::Data & data) { + state.expect(TokenType::RECORD, "when starting a Record."); + parse(state, data.name); + state.expect(TokenType::LBRACE, "when defining a Record."); + while (!state.match(TokenType::RBRACE)) { + switch (state.current.type) { + case TokenType::DEFINE: + case TokenType::ARGUMENT: { + data.constructs.emplace_back(std::in_place_index<0>); + auto & node = std::get<0>(data.constructs.back()); + parse(state, node); + break; + } + case TokenType::ANNOTATE: { + data.constructs.emplace_back(std::in_place_index<1>); + auto & node = std::get<1>(data.constructs.back()); + parse(state, node); + break; + } + case TokenType::SEARCH: { + data.constructs.emplace_back(std::in_place_index<2>); + auto & node = std::get<2>(data.constructs.back()); + parse(state, node); + break; + } + case TokenType::ALIAS: { + auto start = state.current.location; + data.constructs.emplace_back(std::in_place_index<3>); + auto & node = std::get<3>(data.constructs.back()); + state.advance(); + parse(state, node, start); + break; + } + case TokenType::IS: { + auto start = state.current.location; + state.advance(); + if (state.check(TokenType::LBRACE)) { + data.record.emplace<1>(); + parse(state, std::get<1>(data.record), start); + } else { + data.record.emplace<2>(); + parse(state, std::get<2>(data.record), start); + } + 
break; + } + default: { + std::ostringstream oss; + oss << "Unexpected token `" << token_type_name(state.current.type) << "` while parsing top level statements of a Record."; + state.throw_error(oss.str()); + } + } + } +} + +} + diff --git a/libs/autocog/compiler/stl/parser/return.cxx b/libs/autocog/compiler/stl/parser/return.cxx new file mode 100644 index 0000000..aef8cdd --- /dev/null +++ b/libs/autocog/compiler/stl/parser/return.cxx @@ -0,0 +1,79 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +template <> +void Parser::parse(ParserState & state, ast::Data & data) { + if (!state.check(TokenType::USE) && !state.check(TokenType::IS)) { + data.alias.emplace(); + parse(state, data.alias.value().data); + } + if (state.match(TokenType::USE)) { + data.source.emplace<1>(); + parse(state, std::get<1>(data.source).data); + } else if (state.match(TokenType::IS)) { + data.source.emplace<2>(); + parse(state, std::get<2>(data.source).data); + } + + while (!state.match(TokenType::SEMICOLON)) { + switch (state.current.type) { + case TokenType::BIND: + data.clauses.emplace_back(std::in_place_index<0>); + parse(state, std::get<0>(data.clauses.back()).data); + break; + case TokenType::RAVEL: + data.clauses.emplace_back(std::in_place_index<1>); + parse(state, std::get<1>(data.clauses.back()).data); + break; + case TokenType::WRAP: + data.clauses.emplace_back(std::in_place_index<2>); + parse(state, std::get<2>(data.clauses.back()).data); + break; + case TokenType::PRUNE: + data.clauses.emplace_back(std::in_place_index<3>); + parse(state, std::get<3>(data.clauses.back()).data); + break; + default: { + std::ostringstream oss; + oss << "Return clauses can only be `bind`, `ravel`, `wrap`, or `prune`. 
Found " << token_type_name(state.current.type) << "`."; + state.throw_error(oss.str()); + } + } + } +} + +template <> +void Parser::parse(ParserState & state, ast::Data & data) { + state.expect(TokenType::RETURN, "when parsing Return statement."); + if (!state.check(TokenType::LBRACE) && !state.check(TokenType::USE) && !state.check(TokenType::SEMICOLON)) { + data.label.emplace(); + parse(state, data.label.value().data); + } + + if (state.check(TokenType::USE)) { + data.short_form = true; + data.fields.emplace_back(); + auto & rfld = data.fields.back().data; + parse(state, rfld); + + } else if (state.match(TokenType::LBRACE)) { + data.short_form = true; + while (!state.match(TokenType::RBRACE)) { + data.fields.emplace_back(); + auto & rfld = data.fields.back().data; + parse(state, rfld); + } + + } else { + state.expect(TokenType::SEMICOLON, " for empty return."); + } +} + +} + diff --git a/libs/autocog/compiler/stl/parser/search.cxx b/libs/autocog/compiler/stl/parser/search.cxx new file mode 100644 index 0000000..6697356 --- /dev/null +++ b/libs/autocog/compiler/stl/parser/search.cxx @@ -0,0 +1,34 @@ + +#include "autocog/compiler/stl/parser.hxx" + +#if VERBOSE +# include +#endif + +namespace autocog::compiler::stl { + +template <> +void Parser::parse(ParserState & state, ast::Data & param) { + state.expect(TokenType::IDENTIFIER, "parameter locator starts with an identifier."); + param.locator.emplace_back(state.previous.text); + while (state.match(TokenType::DOT)) { + state.expect(TokenType::IDENTIFIER, "parameter locator needs identifier after '.'."); + param.locator.emplace_back(state.previous.text); + } + state.expect(TokenType::IS, " to set value of search parameter."); + parse(state, param.value.data); + state.expect(TokenType::SEMICOLON, " to finish search parameter assignment."); +} + +template <> +void Parser::parse(ParserState & state, ast::Data & search) { + state.expect(TokenType::SEARCH, "when parsing Search statement."); + 
// Concatenates the elements of `vec`, inserting `sep` between consecutive
// elements. Returns the empty string for an empty vector.
static std::string join(const std::vector<std::string> & vec, const std::string & sep) {
  std::string joined;
  bool first = true;
  for (const auto & piece : vec) {
    if (!first) {
      joined += sep;
    }
    joined += piece;
    first = false;
  }
  return joined;
}
+void SymbolScanner::post(ast::Program const &) { + this->scopes.pop_back(); + this->fileid = std::nullopt; +} + +template <> +void SymbolScanner::pre(ast::Import const & node) { + this->shortcut_flag = true; + auto scope = this->scope(); + auto const & filename = node.data.file; + bool has_stl_ext = filename.size() >= 4 && ( filename.rfind(".stl") == filename.size() - 4 ); + bool has_py_ext = filename.size() >= 3 && ( filename.rfind(".py") == filename.size() - 3 ); + if (has_stl_ext) { + auto fid = this->driver.fileid(filename); + if (!fid) { + throw std::runtime_error("STL file should have been parsed!!!"); + } + for (auto const & alias_node: node.data.targets) { + auto const & target = alias_node.data.target; + std::string name; + if (alias_node.data.alias) { + name = alias_node.data.alias.value().data.name; + } else { + if (target.data.config.size() > 0) { + this->driver.emit_error("Imported parametrized objects must be given a local alias (\"as\" keyword).", alias_node.location); + continue; + } + name = target.data.name.data.name; + } + auto alias = scope + "::" + name; + auto sym = UnresolvedImport(fid.value(), alias, target, node); + + if (!this->driver.tables.symbols.emplace(alias, sym).second) { + this->driver.emit_error("Already have a object with name " + name + " in scope " + scope + ".", alias_node.location); + return; + } + } + } else if (has_py_ext) { + for (auto const & alias_node: node.data.targets) { + auto const & target = alias_node.data.target; + std::string name; + if (alias_node.data.alias) { + name = alias_node.data.alias.value().data.name; + } else { + if (target.data.config.size() > 0) { + this->driver.emit_error("Imported parametrized objects must be given a local alias (\"as\" keyword).", alias_node.location); + continue; + } + name = target.data.name.data.name; + } + auto alias = scope + "::" + name; + auto sym = PythonSymbol(filename, alias, target); + + if (!this->driver.tables.symbols.emplace(alias, sym).second) { + 
this->driver.emit_error("Already have a object with name " + name + " in scope " + scope + ".", alias_node.location); + return; + } + } + } else { + this->driver.emit_error("Imported file with unknown extension.", node.location); + } +} + +template <> +void SymbolScanner::pre(ast::Alias const & node) { + this->shortcut_flag = true; + std::string name; + if (node.data.alias) { + name = node.data.alias.value().data.name; + } else { + if (!node.data.is_export) { + this->driver.emit_error("Alias without specifying an alias name.", node.location); + return; + } else if (node.data.target.data.config.size() > 0) { + this->driver.emit_error("Export parametrized object without specifying an alias name.", node.location); + return; + } + name = node.data.target.data.name.data.name; + } + auto scope = this->scope(); + auto alias = scope + "::" + name; + auto sym = UnresolvedAlias(this->fileid.value(), alias, node.data.target, node); + + if (!this->driver.tables.symbols.emplace(alias, sym).second) { + this->driver.emit_error("Already have a object with name " + name + " in scope " + scope + ".", node.location); + return; + } +} + +template <> +void SymbolScanner::pre(ast::Define const & node) { + this->shortcut_flag = true; + auto name = node.data.name.data.name; + auto scope = this->scope(); + auto alias = scope + "::" + name; + auto sym = DefineSymbol(node, alias); + + if (!this->driver.tables.symbols.emplace(alias, sym).second) { + this->driver.emit_error("Already have a object with name " + name + " in scope " + scope + ".", node.data.name.location); + return; + } +} + +template <> +void SymbolScanner::pre(ast::Record const & node) { + auto scope = this->scope(); + auto name = node.data.name.data.name; + this->scopes.push_back(name); + auto alias = this->scope(); + RecordSymbol sym(node, alias); + if (!this->driver.tables.symbols.emplace(alias, sym).second) { + this->driver.emit_error("Already have a object with name " + name + " in scope " + scope + ".", 
node.data.name.location); + return; + } +} + +template <> +void SymbolScanner::post(ast::Record const &) { + this->scopes.pop_back(); +} + +template <> +void SymbolScanner::pre(ast::Prompt const & node) { + auto scope = this->scope(); + auto name = node.data.name.data.name; + this->scopes.push_back(name); + auto alias = this->scope(); + PromptSymbol sym(node, alias); + if (!this->driver.tables.symbols.emplace(alias, sym).second) { + this->driver.emit_error("Already have a object with name " + name + " in scope " + scope + ".", node.data.name.location); + return; + } +} + +template <> +void SymbolScanner::post(ast::Prompt const &) { + this->scopes.pop_back(); +} + +} diff --git a/libs/autocog/compiler/stl/symbol-scanner.hxx b/libs/autocog/compiler/stl/symbol-scanner.hxx new file mode 100644 index 0000000..835f975 --- /dev/null +++ b/libs/autocog/compiler/stl/symbol-scanner.hxx @@ -0,0 +1,72 @@ +#ifndef AUTOCOG_COMPILER_STL_SYMBOL_SCANNER_HXX +#define AUTOCOG_COMPILER_STL_SYMBOL_SCANNER_HXX + +#include "autocog/compiler/stl/symbol-table.hxx" + +namespace autocog::compiler::stl { + +class Driver; + +class SymbolScanner { + private: + Driver & driver; + + std::optional fileid; + std::vector scopes; + bool shortcut_flag; + + std::string scope() const; + + private: + // TODO method for `this->driver.tables.symbols.emplace` that check if it already exist. 
// Debug dump of the symbol table: first every symbol with its kind, then the
// per-object evaluation contexts with their parameter values.
// NOTE(review): the template arguments in this function were stripped by the
// extraction that produced this chunk; the argument lists below are
// reconstructed — the symbol alternatives from the printed labels (they match
// symbols.hxx), the value alternatives from the output format. Confirm the
// value variant's exact types against ir.hxx.
void SymbolTable::dump(std::ostream & os) const {
  os << " === Symbols ===\n";
  for (auto const & [name, sym] : symbols) {
    os << " " << name << " -> ";
    // Print a short tag for each symbol kind (plus origin info when useful).
    std::visit([&os](auto const & s) {
      using T = std::decay_t<decltype(s)>;
      if constexpr (std::is_same_v<T, DefineSymbol>) {
        os << "Define";
      } else if constexpr (std::is_same_v<T, RecordSymbol>) {
        os << "Record";
      } else if constexpr (std::is_same_v<T, PromptSymbol>) {
        os << "Prompt";
      } else if constexpr (std::is_same_v<T, PythonSymbol>) {
        os << "Python(" << s.filename << ")";
      } else if constexpr (std::is_same_v<T, UnresolvedImport>) {
        os << "UnresolvedImport(fid=" << s.fileid << ")";
      } else if constexpr (std::is_same_v<T, UnresolvedAlias>) {
        os << "UnresolvedAlias(fid=" << s.fileid << ")";
      }
    }, sym);
    os << "\n";
  }

  os << " === Contexts ===\n";
  for (auto const & [name, context] : contexts) {
    os << " " << name << ":\n";
    for (auto const & [var, val] : context) {
      os << " " << var << " = ";
      // Render each value alternative: integers/floats verbatim, booleans as
      // true/false, strings quoted, the null alternative as "null".
      std::visit([&os](auto const & v) {
        using T = std::decay_t<decltype(v)>;
        if constexpr (std::is_same_v<T, long>) {            // integer alternative — TODO confirm type
          os << v;
        } else if constexpr (std::is_same_v<T, double>) {   // float alternative — TODO confirm type
          os << v;
        } else if constexpr (std::is_same_v<T, bool>) {
          os << (v ? "true" : "false");
        } else if constexpr (std::is_same_v<T, std::string>) {
          os << "\"" << v << "\"";
        } else if constexpr (std::is_same_v<T, std::monostate>) {
          os << "null";
        }
      }, val);
      os << "\n";
    }
  }

// os << " === Exports ===\n";
// for (auto const & [from, to] : exports) {
//   os << " " << from << " -> " << to << "\n";
// }
}
#ifndef AUTOCOG_COMPILER_STL_SYMBOLS_HXX
#define AUTOCOG_COMPILER_STL_SYMBOLS_HXX

#include "autocog/compiler/stl/ast.hxx"

namespace autocog::compiler::stl {

/// Symbol backed directly by an AST definition node (define/record/prompt).
/// Stores a reference to the node (not owned — the AST must outlive the
/// symbol table) plus the symbol's fully-qualified alias.
template <typename AstNode>
struct AstSymbol {
  AstNode const & node; // declaring AST node (not owned)
  std::string alias;    // fully qualified name ("<scope>::<name>")

  AstSymbol(
    AstNode const & node_,
    std::string const & alias_
  );
};

// NOTE(review): the template arguments of the aliases below were stripped by
// extraction; they are reconstructed from how SymbolScanner constructs these
// symbols — confirm against the original header.
using DefineSymbol = AstSymbol<ast::Define>;
using RecordSymbol = AstSymbol<ast::Record>;
using PromptSymbol = AstSymbol<ast::Prompt>;

/// Symbol imported from a Python file; resolved later by the driver.
struct PythonSymbol {
  std::string filename;          // .py file the object comes from
  std::string alias;             // fully qualified local name
  ast::ObjectRef const & target; // object referenced inside that file

  PythonSymbol(
    std::string const & filename_,
    std::string const & alias_,
    ast::ObjectRef const & target_
  );
};

/// Placeholder for a symbol whose definition lives in another STL file (for
/// imports) or behind another name (for aliases); resolved in a later pass
/// using `fileid`.
template <typename AstNode>
struct UnresolvedSymbol {
  int fileid;                    // file in which `target` must be resolved
  std::string alias;             // fully qualified local name
  ast::ObjectRef const & target; // reference to the imported/aliased object
  AstNode const & node;          // import/alias statement that created it

  UnresolvedSymbol(
    int fileid_,
    std::string const & alias_,
    ast::ObjectRef const & target_,
    AstNode const & node_
  );
};

using UnresolvedImport = UnresolvedSymbol<ast::Import>;
using UnresolvedAlias = UnresolvedSymbol<ast::Alias>;

// Closed set of everything a name can resolve to during scanning.
using AnySymbol = std::variant<
  DefineSymbol, RecordSymbol, PromptSymbol,
  PythonSymbol, UnresolvedImport, UnresolvedAlias
>;

}

#include "autocog/compiler/stl/symbols.txx"

#endif // AUTOCOG_COMPILER_STL_SYMBOLS_HXX
a/libs/autocog/compiler/stl/token.cxx b/libs/autocog/compiler/stl/token.cxx new file mode 100644 index 0000000..3205144 --- /dev/null +++ b/libs/autocog/compiler/stl/token.cxx @@ -0,0 +1,64 @@ + +#include "autocog/compiler/stl/token.hxx" + +namespace autocog::compiler::stl { + +const char * token_type_name(TokenType type) { + switch (type) { + case TokenType::NOT_A_VALID_TOKEN: return "not-a-valid-token"; + case TokenType::DEFINE: return "define"; + case TokenType::ARGUMENT: return "argument"; + case TokenType::RECORD: return "record"; + case TokenType::IMPORT: return "import"; + case TokenType::EXPORT: return "export"; + case TokenType::ALIAS: return "alias"; + case TokenType::PROMPT: return "prompt"; + case TokenType::CHANNEL: return "channel"; + case TokenType::FLOW: return "flow"; + case TokenType::RETURN: return "return"; + case TokenType::ANNOTATE: return "annotate"; + case TokenType::TO: return "to"; + case TokenType::FROM: return "from"; + case TokenType::CALL: return "call"; + case TokenType::MAPPED: return "mapped"; + case TokenType::RAVEL: return "ravel"; + case TokenType::BIND: return "bind"; + case TokenType::WRAP: return "wrap"; + case TokenType::PRUNE: return "prune"; + case TokenType::AS: return "as"; + case TokenType::IS: return "is"; + case TokenType::GET: return "get"; + case TokenType::USE: return "use"; + case TokenType::SEARCH: return "search"; + case TokenType::TEXT: return "text"; + case TokenType::SELECT: return "select"; + case TokenType::REPEAT: return "repeat"; + case TokenType::ENUM: return "enum"; + case TokenType::IDENTIFIER: return "identifier"; + case TokenType::STRING_LITERAL: return "string"; + case TokenType::INTEGER_LITERAL: return "integer literal"; + case TokenType::FLOAT_LITERAL: return "float literal"; + case TokenType::LBRACE: return "'{'"; + case TokenType::RBRACE: return "'}'"; + case TokenType::LSQUARE: return "'['"; + case TokenType::RSQUARE: return "']'"; + case TokenType::LPAREN: return "'('"; + case 
TokenType::RPAREN: return "')'"; + case TokenType::SEMICOLON: return "';'"; + case TokenType::COLON: return "':'"; + case TokenType::COMMA: return "','"; + case TokenType::DOT: return "'.'"; + case TokenType::EQUAL: return "'='"; + case TokenType::PLUS: return "'+'"; + case TokenType::MINUS: return "'-'"; + case TokenType::STAR: return "'*'"; + case TokenType::SLASH: return "'/'"; + case TokenType::LT: return "'<'"; + case TokenType::GT: return "'>'"; + case TokenType::ERROR: return "invalid token"; + case TokenType::END_OF_FILE: return "end of file"; + default: return "unknown"; + } +} + +} diff --git a/libs/autocog/compiler/stl/token.hxx b/libs/autocog/compiler/stl/token.hxx new file mode 100644 index 0000000..ff5e41e --- /dev/null +++ b/libs/autocog/compiler/stl/token.hxx @@ -0,0 +1,102 @@ +#ifndef AUTOCOG_COMPILER_STL_TOKEN_HXX +#define AUTOCOG_COMPILER_STL_TOKEN_HXX + +#include "autocog/compiler/stl/location.hxx" + +#include +#include + +namespace autocog::compiler::stl { + +// Token types +enum class TokenType : int { + NOT_A_VALID_TOKEN, + // Keywords + DEFINE, + ARGUMENT, + RECORD, + IMPORT, + EXPORT, + ALIAS, + PROMPT, + CHANNEL, + FLOW, + RETURN, + ANNOTATE, + TO, + FROM, + CALL, + MAPPED, + RAVEL, + BIND, + WRAP, + PRUNE, + AS, + IS, + GET, + USE, + SEARCH, + TEXT, + SELECT, + REPEAT, + ENUM, + + // Identifiers and literals + IDENTIFIER, + STRING_LITERAL, + INTEGER_LITERAL, + FLOAT_LITERAL, + BOOLEAN_LITERAL, + + // Operators and delimiters + LBRACE, // { + RBRACE, // } + LSQUARE, // [ + RSQUARE, // ] + LPAREN, // ( + RPAREN, // ) + SEMICOLON, // ; + COLON, // : + COMMA, // , + DOT, // . + EQUAL, // = + PLUS, // + + MINUS, // - + STAR, // * + SLASH, // / + LT, // < + GT, // > + + // Comparison operators + LTEQ, // <= + GTEQ, // >= + EQEQ, // == + BANGEQ, // != + + // Logical operators + AMPAMP, // && + PIPEPIPE, // || + BANG, // ! + + // Other operators + PERCENT, // % + QUESTION, // ? 
+ + // Special + ERROR, + END_OF_FILE +}; + +// Token structure +struct Token { + TokenType type{TokenType::NOT_A_VALID_TOKEN}; + std::string text{""}; + SourceLocation location{-1,-1,-1,0}; +}; + +// Helper function to get token type name +const char* token_type_name(TokenType type); + +} + +#endif // AUTOCOG_COMPILER_STL_TOKEN_HXX diff --git a/libs/autocog/llama/xfta/CMakeLists.txt b/libs/autocog/llama/xfta/CMakeLists.txt new file mode 100644 index 0000000..476dea0 --- /dev/null +++ b/libs/autocog/llama/xfta/CMakeLists.txt @@ -0,0 +1,24 @@ + +add_library(autocog_llama_xfta_lib STATIC + fta.cxx + ftt.cxx + model.cxx + manager.cxx + evaluation.cxx + evaluation-choice.cxx + evaluation-completion.cxx + evaluation-text.cxx +) + +target_include_directories(autocog_llama_xfta_lib PUBLIC + ${PROJECT_SOURCE_DIR}/libs + ${PROJECT_SOURCE_DIR}/vendors/headers + ${LLAMA_INC} +) + +target_link_libraries(autocog_llama_xfta_lib PUBLIC + ${LLAMA_LIB} +) + +set_property(TARGET autocog_llama_xfta_lib PROPERTY POSITION_INDEPENDENT_CODE ON) + diff --git a/libs/autocog/llama/xfta/evaluation-choice.cxx b/libs/autocog/llama/xfta/evaluation-choice.cxx new file mode 100644 index 0000000..9b073a9 --- /dev/null +++ b/libs/autocog/llama/xfta/evaluation-choice.cxx @@ -0,0 +1,109 @@ + +#include "autocog/llama/xfta/evaluation.hxx" +#include "autocog/llama/xfta/model.hxx" +#include "autocog/llama/xfta/fta.hxx" +#include "autocog/llama/xfta/ftt.hxx" + +#include + + +#include +#include +#include +#include + +#if VERBOSE +# include +#endif + +#define DEBUG_Evaluation_evaluate_choice VERBOSE && 0 + +namespace autocog::llama::xfta { + +struct ChoiceResult { + size_t index; + ProbaSequence logprobs; + float proba; + + ChoiceResult( + size_t index_, ProbaSequence logprobs_, float proba_ + ) : + index(index_), logprobs(logprobs_), proba(proba_) + {} +}; + +unsigned Evaluation::evaluate_choice(PathState & state) { +#if DEBUG_Evaluation_evaluate_choice + std::cerr << "Executing Choice #" << state.action 
<< std::endl; +#endif + Choice const & action = this->fta.action(state.action).as(); +#if DEBUG_Evaluation_evaluate_choice + std::cerr << " - name: " << action.name << std::endl; + std::cerr << " - width: " << action.width << std::endl; + std::cerr << " - number of choices: " << action.choices.size() << std::endl; +#endif + + if (action.choices.empty()) { + throw std::runtime_error("Choice action has no choices"); + } + + if (action.successors.size() != action.choices.size()) { + throw std::runtime_error("Choice action must have as many successors as choices"); + } + + unsigned num_token_eval = 0; + std::vector results; + + // Evaluate ALL choices in full + for (size_t idx = 0; idx < action.choices.size(); ++idx) { + auto [model, ctx] = this->restore(state); +#if DEBUG_Evaluation_evaluate_choice + std::cerr << " - Model #" << model.id << std::endl; + std::cerr << " - Context #" << ctx << std::endl; + std::cerr << " - choice[" << idx << "]:" << std::endl; + std::cerr << " - number of tokens: " << action.choices[idx].size() << std::endl; +#endif + + // Save current state to restore after evaluation + TokenSequence saved_tokens = model.get_tokens_const(ctx); + + ProbaSequence logprobs; + num_token_eval += model.eval_sequences(action.choices[idx], logprobs, ctx); + + float proba = 0.; + for (float lpb : logprobs) proba += lpb; + proba = std::exp(-proba/logprobs.size()); +#if DEBUG_Evaluation_evaluate_choice + std::cerr << " - proba: " << proba << std::endl; +#endif + + results.emplace_back(idx, logprobs, proba); + + state.context.reset(); // TODO remove once context saving/restore/rewind is implemented + } + + std::sort(results.begin(), results.end(), + [](const ChoiceResult& a, const ChoiceResult& b) { + return a.proba > b.proba; + }); + + unsigned count = 0; + for (const auto & result : results) { + auto & choice_tokens = action.choices[result.index]; + FTT & child = state.parent.add(action.id, choice_tokens, result.logprobs); + child.pruned = ( count > action.width 
) || (count > 0 && result.proba < action.threshold); + if (!child.pruned) { + this->enqueue(action.successors[result.index], child, state); + } + count++; + } + +#if DEBUG_Evaluation_evaluate_choice + std::cerr << " > evaluated: " << num_token_eval << std::endl; +#endif + + return num_token_eval; +} + +} + diff --git a/libs/autocog/llama/xfta/evaluation-completion.cxx b/libs/autocog/llama/xfta/evaluation-completion.cxx new file mode 100644 index 0000000..73f440d --- /dev/null +++ b/libs/autocog/llama/xfta/evaluation-completion.cxx @@ -0,0 +1,314 @@ + +#include "autocog/llama/xfta/evaluation.hxx" +#include "autocog/llama/xfta/model.hxx" +#include "autocog/llama/xfta/fta.hxx" +#include "autocog/llama/xfta/ftt.hxx" + +#include + +#include +#include +#include +#include +#include +#include + +#if VERBOSE +# include +#endif + +#define DEBUG_Evaluation_evaluate_completion VERBOSE && 0 +#define DEBUG_expand_beam VERBOSE && 0 +#define DEBUG_beam_search_step VERBOSE && 0 + +namespace autocog::llama::xfta { + +struct BeamState { + TokenSequence tokens; + ProbaSequence logprobs; + float logprob{0.}; + float repetition_penalty = 1.0f; + float diversity_bonus = 0.0f; + float lookahead_bonus = 0.0f; + bool stopped{false}; + + float proba() const { + return std::exp(-logprob/logprobs.size()); + } + + float score() const { + return (this->proba() + lookahead_bonus + diversity_bonus) / repetition_penalty; + } +}; + +static float calculate_repetition_penalty( + TokenSequence const & tokens, float & penalty, + float const penalty_weight, + size_t const min_length = 3, + size_t const max_window_size = 256, + float const length_weight = 1.0, + float const recency_weight = 1.0 +) { + + size_t window_size = std::min(tokens.size(), max_window_size); + for (size_t i = min_length; i < tokens.size(); ++i) { + size_t best_length = 0; + size_t best_distance = 0; + + size_t search_start = (i >= window_size) ? 
i - window_size : 0; + + for (size_t j = search_start; j < i; ++j) { + size_t match_length = 0; + + while (j + match_length < i && + i + match_length < tokens.size() && + tokens[j + match_length] == tokens[i + match_length]) { + match_length++; + } + + if (match_length >= min_length && match_length > best_length) { + best_length = match_length; + best_distance = i - j; + } + } + + if (best_length >= min_length) { + // Stronger penalty for: + // - Longer repetitions + // - Recent repetitions (smaller distance) + float length_factor = std::log(1. + length_weight * best_length ); + float recency_factor = std::log(1. + recency_weight * best_distance); + penalty *= (1.0f + penalty_weight * length_factor / recency_factor); + } + } + return penalty; +} + +static float token_sequence_diversity(TokenSequence const & a, TokenSequence const & b) { + // Jaccard distance or edit distance + std::set set_a(a.begin(), a.end()); + std::set set_b(b.begin(), b.end()); + + std::set intersection; + std::set_intersection( + set_a.begin(), set_a.end(), + set_b.begin(), set_b.end(), + std::inserter(intersection, intersection.begin()) + ); + + std::set union_set; + std::set_union( + set_a.begin(), set_a.end(), + set_b.begin(), set_b.end(), + std::inserter(union_set, union_set.begin()) + ); + + return 1.0f - (float)intersection.size() / union_set.size(); +} + +static void calculate_diversity_bonuses(std::vector & beams, float const weight) { + for (size_t i = 0; i < beams.size(); ++i) { + float diversity = 0.0f; + for (size_t j = 0; j < beams.size(); ++j) { + if (i != j) { + diversity += token_sequence_diversity(beams[i].tokens, beams[j].tokens); + } + } + beams[i].diversity_bonus = weight * diversity / (beams.size() - 1); + } +} + +static unsigned expand_beam( + Model & model, + ContextID ctx, + BeamState const & beam, + Completion const & action, + TokenSequence const & base_tokens, + std::vector & beams +) { +#if DEBUG_expand_beam + std::cerr << "expand_beam(...):" << std::endl; + 
std::cerr << " - base_tokens.size() = " << base_tokens.size() << std::endl; + std::cerr << " - beam.tokens.size() = " << beam.tokens.size() << std::endl; +#endif + TokenSequence context_tokens = base_tokens; + context_tokens.insert(context_tokens.end(), beam.tokens.begin(), beam.tokens.end()); + model.set_tokens(context_tokens, ctx); + + std::vector topk_tokens; + std::vector topk_logits; + unsigned num_token_eval = model.eval_topk_tokens( + action.vocab.mask, + action.beams, + topk_tokens, + topk_logits, + ctx + ); + + for (size_t i = 0; i < topk_tokens.size(); ++i) { + BeamState & new_beam = beams.emplace_back(beam); + + new_beam.tokens.push_back(topk_tokens[i]); + new_beam.logprobs.push_back(topk_logits[i]); + new_beam.logprob += topk_logits[i]; + + if (action.repetition) { + TokenSequence beam_tokens; + beam_tokens.insert(beam_tokens.end(), base_tokens.begin(), base_tokens.end()); + beam_tokens.insert(beam_tokens.end(), new_beam.tokens.begin(), new_beam.tokens.end()); + calculate_repetition_penalty( + beam_tokens, new_beam.repetition_penalty, + action.repetition.value() + ); + } + + new_beam.stopped = ( + action.stop.size() <= new_beam.tokens.size() + ) && std::equal( + action.stop.begin(), + action.stop.end(), + new_beam.tokens.end() - action.stop.size() + ); + if (new_beam.stopped) { + new_beam.tokens.erase(new_beam.tokens.end() - action.stop.size(), new_beam.tokens.end()); + } + } + return num_token_eval; +} + +// Prune beams to keep only top k +static void prune_beams( + std::vector & beams, + unsigned beam_width +) { + // Sort by score + std::sort(beams.begin(), beams.end(), [](BeamState const & a, BeamState const & b) { + if (a.stopped != b.stopped) return a.stopped; + return a.score() > b.score(); + }); + + // Keep top beams + std::vector pruned; + size_t kept_active = 0; + for (BeamState const & beam : beams) { + if (beam.stopped) { + pruned.push_back(beam); + } else if (kept_active < beam_width) { + pruned.push_back(beam); + kept_active++; + } + } + + 
beams = std::move(pruned); +} + +// Run beam search for one position +static bool beam_search_step( + Model & model, + ContextID ctx, + Completion const & action, + TokenSequence const & base_tokens, + std::vector & current_beams, + unsigned & num_token_eval +) { +#if DEBUG_beam_search_step + std::cerr << "beam_search_step(...):" << std::endl; + std::cerr << " - current_beams.size() = " << current_beams.size() << std::endl; +#endif + std::vector next_beams; + + for (BeamState const & beam : current_beams) { + if (beam.stopped) { + next_beams.push_back(beam); + } else { + num_token_eval += expand_beam(model, ctx, beam, action, base_tokens, next_beams); + } + } + if (action.diversity) { + calculate_diversity_bonuses(next_beams, action.diversity.value()); + } + + // Check for early termination + bool all_stopped = std::all_of(next_beams.begin(), next_beams.end(), [](BeamState const & b) { return b.stopped; }); + + if (all_stopped) { + current_beams = std::move(next_beams); + return true; // Signal early termination + } + + // Prune and update beams + prune_beams(next_beams, action.beams); + + if (next_beams.empty()) { + throw std::runtime_error("No valid beams remaining in completion"); + } + + current_beams = std::move(next_beams); + return false; // Continue beam search +} + +unsigned Evaluation::evaluate_completion(PathState & state) { +#if DEBUG_Evaluation_evaluate_completion + std::cerr << "Executing Completion #" << state.action << std::endl; +#endif + Completion const & action = this->fta.action(state.action).as(); +#if DEBUG_Evaluation_evaluate_completion + std::cerr << " - name: " << action.name << std::endl; + std::cerr << " - beams: " << action.beams << std::endl; + std::cerr << " - length: " << action.length << std::endl; + std::cerr << " - ahead: " << action.ahead << std::endl; + std::cerr << " - width: " << action.width << std::endl; + std::cerr << " - threshold: " << action.threshold << std::endl; +#endif + auto [model, ctx] = this->restore(state); +#if 
DEBUG_Evaluation_evaluate_completion + std::cerr << " - Model #" << model.id << std::endl; + std::cerr << " - Context #" << ctx << std::endl; +#endif + + std::vector beams; + beams.emplace_back(); + + unsigned num_token_eval = 0; + for (unsigned pos = 0; pos < action.length; ++pos) { +#if DEBUG_Evaluation_evaluate_completion + std::cerr << " - pos[" << pos << "]..." << std::endl; +#endif + bool should_stop = beam_search_step( + model, ctx, action, state.tokens, beams, num_token_eval + ); + + if (should_stop) { + break; + } + } + + std::sort(beams.begin(), beams.end(), [](BeamState const & a, BeamState const & b) { + return a.score() > b.score(); + }); + + unsigned count = 0; + for (auto & beam: beams) { +#if DEBUG_Evaluation_evaluate_completion + std::cerr << " - beam[" << count << "]:" << std::endl; + std::cerr << " length: " << beam.tokens.size() << std::endl; + std::cerr << " logprob: " << beam.logprob << std::endl; + std::cerr << " proba: " << beam.proba() << std::endl; +#endif + FTT & child = state.parent.add(action.id, beam.tokens, beam.logprobs); + child.pruned = ( count > action.width ) || (count > 0 && beam.proba() < action.threshold); + if (!child.pruned) { + this->enqueue(action.successors[0], child, state); + } + count++; + } + +#if DEBUG_Evaluation_evaluate_completion + std::cerr << " > evaluated: " << num_token_eval << std::endl; +#endif + + return num_token_eval; +} + +} + diff --git a/libs/autocog/llama/xfta/evaluation-text.cxx b/libs/autocog/llama/xfta/evaluation-text.cxx new file mode 100644 index 0000000..b4da07d --- /dev/null +++ b/libs/autocog/llama/xfta/evaluation-text.cxx @@ -0,0 +1,53 @@ + +#include "autocog/llama/xfta/manager.hxx" +#include "autocog/llama/xfta/evaluation.hxx" +#include "autocog/llama/xfta/model.hxx" +#include "autocog/llama/xfta/fta.hxx" +#include "autocog/llama/xfta/ftt.hxx" + +#include + +#if VERBOSE +# include +#endif + +#define DEBUG_Evaluation_evaluate_text VERBOSE && 0 + +namespace autocog::llama::xfta { + +unsigned 
Evaluation::evaluate_text(PathState & state) { +#if DEBUG_Evaluation_evaluate_text + std::cerr << "Executing Text #" << state.action << std::endl; +#endif + Text const & action = this->fta.action(state.action).as(); +#if DEBUG_Evaluation_evaluate_text + std::cerr << " - name: " << action.name << std::endl; + std::cerr << " - number of tokens: " << action.tokens.size() << std::endl; +#endif + + unsigned num_token_eval = 0; + ProbaSequence logprobs(action.tokens.size(), 0.); + if (action.evaluate) { + auto [model,ctx] = this->restore(state); +#if DEBUG_Evaluation_evaluate_text + std::cerr << " - Model #" << model.id << std::endl; + std::cerr << " - Context #" << ctx << std::endl; +#endif + num_token_eval += model.eval_sequences(action.tokens, logprobs, ctx); + } +#if DEBUG_Evaluation_evaluate_text + std::cerr << " > evaluated: " << num_token_eval << std::endl; +#endif + + auto & child = state.parent.add(action.id, action.tokens, logprobs); + if (action.successors.size() == 1) { + this->enqueue(action.successors[0], child, state); + } else if (action.successors.size() > 1) { + throw std::runtime_error("Text action should never have more than 1 successor."); + } + + return num_token_eval; +} + +} + diff --git a/libs/autocog/llama/xfta/evaluation.cxx b/libs/autocog/llama/xfta/evaluation.cxx new file mode 100644 index 0000000..45ac59c --- /dev/null +++ b/libs/autocog/llama/xfta/evaluation.cxx @@ -0,0 +1,121 @@ + +#include "autocog/llama/xfta/manager.hxx" +#include "autocog/llama/xfta/evaluation.hxx" +#include "autocog/llama/xfta/model.hxx" +#include "autocog/llama/xfta/ftt.hxx" +#include "autocog/llama/xfta/fta.hxx" + +#if VERBOSE +# include +#endif + +#define DEBUG_Evaluation_enqueue VERBOSE && 0 +#define DEBUG_Evaluation_advance VERBOSE && 0 + +namespace autocog::llama::xfta { + +PathState::PathState(ActionID const action_, FTT & parent_, TokenSequence const & tokens_, std::optional context_) : + action(action_), + parent(parent_), + tokens(tokens_), + 
context(context_) +{} + +float PathState::proba() const { + return this->parent.proba(); +} + +Evaluation::Evaluation(EvaluationConfig const & config_, ModelID const model_, FTA const & fta_) : + config(config_), + model(model_), + fta(fta_), + queue(), + root(nullptr) +{} + +Evaluation::~Evaluation() { + if (this->root) + delete this->root; +} + +unsigned Evaluation::advance(std::optional max_token_eval) { + if (this->root == nullptr) this->initial(); + + unsigned num_action_eval = 0; + unsigned num_token_eval = 0; + while (!queue.empty() && (max_token_eval == std::nullopt || num_token_eval < max_token_eval)) { + PathState & state = queue.front(); +#if VERBOSE + std::cerr << "Evaluation::advance [ Q=" << queue.size() << ", A=" << num_action_eval << ", T=" << num_token_eval << " ]" << std::endl; +#endif +#if DEBUG_Evaluation_advance + std::cerr << " state.action = " << state.action << std::endl; + std::cerr << " state.tokens.size() = " << state.tokens.size() << std::endl; + std::cerr << " state.proba() = " << state.proba() << std::endl; + std::cerr << " state.parent.length = " << state.parent.length << std::endl; + std::cerr << " state.parent.logprob = " << state.parent.logprob << std::endl; +#endif + Action const & action = this->fta.action(state.action); + switch (action.kind) { + case ActionKind::Text: + num_token_eval += this->evaluate_text(state); + break; + case ActionKind::Completion: + num_token_eval += this->evaluate_completion(state); + break; + case ActionKind::Choice: + num_token_eval += this->evaluate_choice(state); + break; + } + queue.pop(); + num_action_eval++; + } + + return num_token_eval; +} + +FTT const & Evaluation::retrieve() const { + return *(this->root); +} + +void Evaluation::initial() { + Text const & init = this->fta.action(0).as(); + TokenSequence tokens = init.tokens; + TokenID bos = Manager::get_model(this->model).bos_token(); + if (tokens[0] != bos) { + tokens.insert(tokens.begin(), bos); + } + this->root = 
FTT::make_root(init.tokens); + this->queue.emplace(init.successors[0], *(this->root), init.tokens, std::nullopt); +} + +void Evaluation::enqueue( + ActionID const action, + FTT & parent, + PathState const & state +) { +#if DEBUG_Evaluation_enqueue + std::cerr << ">> Evaluation::enqueue <<" << std::endl; +#endif + std::optional ctx = state.context; + ctx.reset(); // TODO context saving logic + + std::vector tokens(state.tokens.begin(), state.tokens.end()); + tokens.insert(tokens.end(), parent.tokens.begin(), parent.tokens.end()); + + this->queue.emplace(action, parent, tokens, ctx); +} + +std::pair Evaluation::restore(PathState & state) const { + Model & model = Manager::get_model(this->model); + + if (!state.context) { + state.context = 0; // TODO look at existing context for the largest prefix? + } + model.set_tokens(state.tokens, state.context.value()); + + return std::pair(model, state.context.value()); +} + +} + diff --git a/libs/autocog/llama/xfta/evaluation.hxx b/libs/autocog/llama/xfta/evaluation.hxx new file mode 100644 index 0000000..1597e8b --- /dev/null +++ b/libs/autocog/llama/xfta/evaluation.hxx @@ -0,0 +1,64 @@ +#ifndef AUTOCOG_LLAMA_XFTA_EVALUATION_HXX +#define AUTOCOG_LLAMA_XFTA_EVALUATION_HXX + +#include "autocog/llama/xfta/types.hxx" + +#include +#include + +namespace autocog::llama::xfta { + +class Model; +class FTA; +class FTT; +class Text; +class Completion; +class Choice; + +struct PathState { + ActionID const action; //< Action to be evaluated next + FTT & parent; //< Previous FTT in the path, results of exploring this state will be added to that tree + TokenSequence const tokens; //< Tokens that lead to this state + std::optional context; //< Context used to evaluate this path + + PathState(ActionID const action_, FTT & parent, std::vector const & tokens_, std::optional context); + float proba() const; +}; + +struct EvaluationConfig { + bool evaluate_text{true}; +}; + +class Evaluation { + public: + using Queue = std::queue; + 
EvaluationConfig const config; + + private: + ModelID const model; + FTA const & fta; + + Queue queue; + FTT * root; + + protected: + std::pair restore(PathState & state) const; + + void initial(); + void enqueue(ActionID const action, FTT & parent, PathState const & current); + + unsigned evaluate_text (PathState & state); + unsigned evaluate_completion (PathState & state); + unsigned evaluate_choice (PathState & state); + + public: + Evaluation(EvaluationConfig const & config_, ModelID const model_, FTA const & fta_); + ~Evaluation(); + unsigned advance(std::optional max_token_eval); + FTT const & retrieve() const; +}; + +} + +#endif /* AUTOCOG_LLAMA_XFTA_EVALUATION_HXX */ + diff --git a/libs/autocog/llama/xfta/fta.cxx b/libs/autocog/llama/xfta/fta.cxx new file mode 100644 index 0000000..eff305c --- /dev/null +++ b/libs/autocog/llama/xfta/fta.cxx @@ -0,0 +1,71 @@ + +#include "autocog/llama/xfta/fta.hxx" +#include "autocog/llama/xfta/model.hxx" + +#include +#include + +namespace autocog::llama::xfta { + +Action::Action( + ActionKind const kind_, + ActionID const id_, + std::string const & name_ +) : + kind(kind_), + id(id_), + name(name_), + successors() +{} + +Text::Text( + ActionID const id_, + std::string const & name_, + bool const eval +) : + Action(ActionKind::Text, id_, name_), + evaluate(eval) +{} + +Completion::Completion( + ActionID const id_, + std::string const & name_, + float threshold_, + unsigned length_, + unsigned beams_, + unsigned ahead_, + unsigned width_, + std::optional repetition_, + std::optional diversity_ +) : + Action(ActionKind::Completion, id_, name_), + threshold(threshold_), + length(length_), + beams(beams_), + ahead(ahead_), + width(width_), + repetition(repetition_), + diversity(diversity_) +{} + +Choice::Choice( + ActionID const id_, + std::string const & name_, + float threshold_, + unsigned width_ +) : + Action(ActionKind::Choice, id_, name_), + threshold(threshold_), + width(width_) +{} + +Action const & FTA::action(ActionID 
const id) const { + Action const & action = *(this->actions.at(id)); + if (action.id != id) { + throw std::runtime_error("Action's ID does not match position in FTA::actions!"); + } + return action; +} + +} + diff --git a/libs/autocog/llama/xfta/fta.hxx b/libs/autocog/llama/xfta/fta.hxx new file mode 100644 index 0000000..7e46c05 --- /dev/null +++ b/libs/autocog/llama/xfta/fta.hxx @@ -0,0 +1,99 @@ +#ifndef AUTOCOG_LLAMA_XFTA_FTA_HXX +#define AUTOCOG_LLAMA_XFTA_FTA_HXX + +#include "autocog/llama/xfta/types.hxx" + +#include +#include +#include +#include +#include + +namespace autocog::llama::xfta { + +class Model; + +struct Vocab { + std::vector mask; + + void topk(unsigned k, std::vector const & input, std::vector> const & output) const; +}; + +enum class ActionKind { + Text, + Completion, + Choice +}; + +struct Action { + ActionKind const kind; + ActionID const id; + std::string const name; + + std::vector successors; + + Action(ActionKind const kind_, ActionID const id_, std::string const & name_); + + template + T const & as() const { + if (T::Kind != kind) { + throw std::runtime_error("Calling Action::as() with uncompatible ActionKind."); + } + return static_cast(*this); + } +}; + +struct Text : public Action { + static constexpr ActionKind Kind = ActionKind::Text; + + bool const evaluate; //< whether to evaluate the probability using the model (else p=1.) 
+ + TokenSequence tokens; + + Text(ActionID const id_, std::string const & name_, bool const eval); +}; + +struct Completion : public Action { + static constexpr ActionKind Kind = ActionKind::Completion; + + float const threshold; //< Probability threshold for pruning + unsigned const length; //< Maximum length of the completion + unsigned const beams; //< Number of concurrent exploration beams + unsigned const ahead; //< Look ahead parameter for beam search + unsigned const width; //< Maximum number of beams to select + + std::optional const repetition; //< Penalize repeting pattern + std::optional const diversity; //< Encourage diversity across beams + + Vocab vocab; + TokenSequence stop; + + Completion( + ActionID const id_, std::string const & name_, + float threshold_, unsigned length_, + unsigned beams_, unsigned ahead_, unsigned width_, + std::optional repetition_, + std::optional diversity_ + ); +}; + +struct Choice : public Action { + static constexpr ActionKind Kind = ActionKind::Choice; + + float const threshold; //< Probability threshold for pruning + unsigned const width; // Maximum number of choices to select + + std::vector choices; // Each choice is a token sequence + + Choice(ActionID const id_, std::string const & name_, float threshold_, unsigned width_); +}; + +struct FTA { + Action const & action(ActionID const id) const; + std::vector> actions; +}; + +} + +#endif /* AUTOCOG_LLAMA_XFTA_FTA_HXX */ + diff --git a/libs/autocog/llama/xfta/ftt.cxx b/libs/autocog/llama/xfta/ftt.cxx new file mode 100644 index 0000000..375bd03 --- /dev/null +++ b/libs/autocog/llama/xfta/ftt.cxx @@ -0,0 +1,58 @@ + +#include "autocog/llama/xfta/ftt.hxx" + +#include + +#if VERBOSE +# include +#endif + +#define DEBUG_FTT_add VERBOSE && 0 + +namespace autocog::llama::xfta { + +FTT::FTT( + ActionID const action_, + TokenSequence const & tokens_, + ProbaSequence const & logprobs_, + float logprob_, + unsigned length_ +) : + action(action_), + tokens(tokens_), + 
logprobs(logprobs_), + logprob(logprob_), + length(length_), + pruned(false), + children() +{} + +FTT & FTT::add( + ActionID const action_, + TokenSequence const & tokens_, + ProbaSequence const & logprobs_ +) { +#if DEBUG_FTT_add + std::cerr << ">> FTT::add <<" << std::endl; +#endif + float logprob_ = this->logprob; + for (auto lpb: logprobs_) logprob_ += lpb; + this->children.emplace_back(action_, tokens_, logprobs_, logprob_, this->length + tokens_.size()); + return this->children.back(); +} + +FTT * FTT::make_root(TokenSequence const & tokens) { + ProbaSequence logprobs(tokens.size(), 0.); + return new FTT(0, tokens, logprobs, 0, tokens.size()); +} + +float FTT::proba() const { + return std::exp(-this->logprob / this->length); +} + +std::list const & FTT::get_children() const { + return this->children; +} + +} + diff --git a/libs/autocog/llama/xfta/ftt.hxx b/libs/autocog/llama/xfta/ftt.hxx new file mode 100644 index 0000000..e8ef158 --- /dev/null +++ b/libs/autocog/llama/xfta/ftt.hxx @@ -0,0 +1,40 @@ +#ifndef AUTOCOG_LLAMA_XFTA_FTT_HXX +#define AUTOCOG_LLAMA_XFTA_FTT_HXX + +#include "autocog/llama/xfta/types.hxx" + +#include + +namespace autocog::llama::xfta { + +class FTT { + public: + ActionID const action; //< Action evaluated for this node + TokenSequence const tokens; //< Tokens generated at this node + ProbaSequence const logprobs; //< Logprob for each token + + float const logprob; //< Cumulative logprob from the root + unsigned const length; //< Total length from the root + + bool pruned{false}; + private: + std::list children; + + public: + /// I'd like that constructor to be private but it prevents the use of `emplace_back` in `add`. + /// Adding `friend class std::list;` does not solve the issue... 
+ FTT(ActionID const action_, TokenSequence const & tokens_, ProbaSequence const & logprobs_, float logprob_, unsigned length_); + + public: + static FTT * make_root(TokenSequence const & tokens_); + FTT & add(ActionID const action_, TokenSequence const & tokens_, ProbaSequence const & logprobs_); + + float proba() const; + + std::list const & get_children() const; +}; + +} + +#endif /* AUTOCOG_LLAMA_XFTA_FTT_HXX */ + diff --git a/libs/autocog/llama/xfta/manager.cxx b/libs/autocog/llama/xfta/manager.cxx new file mode 100644 index 0000000..1f64eee --- /dev/null +++ b/libs/autocog/llama/xfta/manager.cxx @@ -0,0 +1,107 @@ + +#include "autocog/llama/xfta/manager.hxx" +#include "autocog/llama/xfta/evaluation.hxx" + +#include +#include +#include + +#if VERBOSE +# include +#endif + +namespace autocog::llama::xfta { + +void quiet_log_callback(enum ggml_log_level level, const char * text, [[maybe_unused]] void * user_data) { + if (level == GGML_LOG_LEVEL_ERROR) { + fprintf(stderr, "%s", text); + } +} + +bool Manager::initialized{false}; + +Manager & Manager::instance() { + static Manager __instance; + return __instance; +} + +Manager::~Manager() { + cleanup(); +} + +void Manager::cleanup() { +#if VERBOSE + std::cerr << "Manager::cleanup()" << std::endl; +#endif + if (Manager::initialized) { + evaluations.clear(); + models.clear(); + llama_backend_free(); + Manager::initialized = false; + } +} + +void Manager::initialize() { +#if VERBOSE + std::cerr << "Manager::initialize()" << std::endl; +#endif +#if VERBOSE == 0 + llama_log_set(quiet_log_callback, nullptr); +#endif + llama_backend_init(); + + auto & manager = instance(); + manager.models.emplace_back(); // adding Model #0 which is a simple character level random number generator (for ultra-fast testing) + + std::atexit([]() { + instance().cleanup(); + }); + Manager::initialized = true; +} + +ModelID Manager::add_model(std::string const & path, int n_ctx) { + auto & manager = instance(); + ModelID id = 
manager.models.size(); + manager.models.emplace_back(id, path, n_ctx); + return id; +} + +Model & Manager::get_model(ModelID id) { + auto & manager = instance(); + return manager.models[id]; +} + +EvalID Manager::add_eval(ModelID const model, FTA const & fta) { + auto & manager = instance(); + EvalID id = manager.next_eval_id++; + EvaluationConfig config; + manager.evaluations.try_emplace(id, config, model, fta); + return id; +} + +Evaluation & Manager::get_eval(EvalID id) { + auto & manager = instance(); + auto it = manager.evaluations.find(id); + if (it == manager.evaluations.end()) { + throw std::runtime_error("Invalid Evaluation ID: " + std::to_string(id)); + } + return it->second; +} + +unsigned Manager::advance(EvalID id, std::optional max_token_eval) { + auto & eval = get_eval(id); + return eval.advance(max_token_eval); +} + +FTT const & Manager::retrieve(EvalID id) { + auto & eval = get_eval(id); + return eval.retrieve(); +} + +void Manager::rm_eval(EvalID id) { + auto & manager = instance(); + manager.evaluations.erase(id); +} + +} + diff --git a/libs/autocog/llama/xfta/manager.hxx b/libs/autocog/llama/xfta/manager.hxx new file mode 100644 index 0000000..46afa71 --- /dev/null +++ b/libs/autocog/llama/xfta/manager.hxx @@ -0,0 +1,48 @@ +#ifndef AUTOCOG_LLAMA_XFTA_MANAGER_HXX +#define AUTOCOG_LLAMA_XFTA_MANAGER_HXX + +#include "autocog/llama/xfta/types.hxx" +#include "autocog/llama/xfta/model.hxx" +#include "autocog/llama/xfta/evaluation.hxx" + +#include +#include +#include + +namespace autocog::llama::xfta { + +class Evaluation; +class FTA; + +class Manager { + public: + static bool initialized; + private: + std::vector models; + + EvalID next_eval_id = 0; + std::unordered_map evaluations; + + Manager() = default; + void cleanup(); + static Manager & instance(); + + public: + ~Manager(); + + static void initialize(); + + static ModelID add_model(std::string const & path, int n_ctx); + static Model & get_model(ModelID id); + + static EvalID add_eval(ModelID 
const model_, FTA const & fta); + static Evaluation & get_eval(EvalID id); + static unsigned advance(EvalID id, std::optional max_token_eval=std::nullopt); + static FTT const & retrieve(EvalID id); + static void rm_eval(EvalID id); +}; + +} + +#endif /* AUTOCOG_LLAMA_XFTA_MANAGER_HXX */ + diff --git a/libs/autocog/llama/xfta/model.cxx b/libs/autocog/llama/xfta/model.cxx new file mode 100644 index 0000000..649967a --- /dev/null +++ b/libs/autocog/llama/xfta/model.cxx @@ -0,0 +1,382 @@ + +#include "autocog/llama/xfta/model.hxx" + +#include + +#include +#include +#include + +#if VERBOSE +# include +#endif + +#define DEBUG_Model_set_tokens VERBOSE && 0 +#define DEBUG_Model_eval_sequences VERBOSE && 0 +#define DEBUG_Model_eval_topk_tokens VERBOSE && 0 + +namespace autocog::llama::xfta { + +Model::Model() : + id(0), + model(nullptr), + contexts() +{} + +Model::Model(ModelID const id_, std::string const & model_path, int n_ctx) : + id(id_), + model(nullptr), + contexts() +{ + // Load model + llama_model_params model_params = llama_model_default_params(); + this->model = llama_model_load_from_file(model_path.c_str(), model_params); + if (!this->model) { + throw std::runtime_error("Failed to load model from: " + model_path); + } + + // Create context parameters + llama_context_params ctx_params = llama_context_default_params(); + ctx_params.n_ctx = n_ctx; + + // Create single context with ID=0 (and associated token sequence) + llama_context * ctx = llama_init_from_model(this->model, ctx_params); + if (!ctx) { + llama_model_free(this->model); + throw std::runtime_error("Failed to create llama context"); + } + this->contexts.push_back(ctx); + this->tokens.emplace_back(); +} + +Model::~Model() { +#if VERBOSE + std::cerr << "Model::~Model(" << this->id << ")" << std::endl; +#endif + if (this->id == 0) { + // NOP + } else { + // for (auto* ctx : this->contexts) if (ctx) llama_free(ctx); + contexts.clear(); + if (this->model) llama_model_free(model); + } +} + +void 
Model::check_context_id(ContextID const id) const { + if (this->contexts.size() != this->tokens.size()) { + throw std::runtime_error("Discrepency between contexts and tokens vector size."); + } + if (id >= this->contexts.size()) { + throw std::runtime_error("Invalid context ID: " + std::to_string(id)); + } + if (this->contexts[id] == nullptr) { + throw std::runtime_error("Missing context ID: " + std::to_string(id)); + } +} + +llama_context * Model::get_context(ContextID const id) const { + check_context_id(id); + return this->contexts[id]; +} + +TokenSequence const & Model::get_tokens_const(ContextID const id) const { + check_context_id(id); + return this->tokens[id]; +} + +TokenSequence & Model::get_tokens(ContextID const id) { + check_context_id(id); + return this->tokens[id]; +} + +ContextID Model::fork_context(ContextID const) { + throw std::runtime_error("Context forking is not implemented yet (Phase 3 feature)"); +} + +TokenSequence Model::tokenize(std::string const & text, bool add_bos, bool special) { + if (this->id == 0) { + throw std::runtime_error("Using model #0 (RNG) is not implemented yet!"); + } + std::vector tokens; + tokens.resize(text.length() + (add_bos ? 
1 : 0) + 1); // Rough upper bound + + int n_tokens = llama_tokenize( + this->get_vocab(), + text.c_str(), + text.length(), + tokens.data(), + tokens.size(), + add_bos, + special + ); + + if (n_tokens < 0) { + throw std::runtime_error("Tokenization failed for text: " + text); + } + + TokenSequence result(tokens.begin(), tokens.begin() + n_tokens); + return result; +} + +std::string Model::detokenize(TokenSequence const & tokens, bool spec_rm, bool spec_unp) { + if (this->id == 0) { + throw std::runtime_error("Using model #0 (RNG) is not implemented yet!"); + } + + if (tokens.empty()) { + return ""; + } + + // Detokenize + std::string result; + result.resize(tokens.size() * 8); // Rough estimate for buffer size + + int n_chars = llama_detokenize( + this->get_vocab(), + tokens.data(), + tokens.size(), + &result[0], + result.size(), + spec_rm, + spec_unp + ); + + if (n_chars < 0) { + throw std::runtime_error("Detokenization failed"); + } + + result.resize(n_chars); + return result; +} + +const llama_vocab * Model::get_vocab() const { + if (this->id == 0) { + throw std::runtime_error("Using model #0 (RNG) is not implemented yet!"); + } + return llama_model_get_vocab(this->model); +} + +size_t Model::vocab_size() const { + if (this->id == 0) { + throw std::runtime_error("Using model #0 (RNG) is not implemented yet!"); + } + return llama_vocab_n_tokens(this->get_vocab()); +} + +TokenID Model::bos_token() const { + if (this->id == 0) { + throw std::runtime_error("Using model #0 (RNG) is not implemented yet!"); + } + return llama_vocab_bos(this->get_vocab()); +} + +TokenID Model::eos_token() const { + if (this->id == 0) { + throw std::runtime_error("Using model #0 (RNG) is not implemented yet!"); + } + return llama_vocab_eos(this->get_vocab()); +} + +static float logit_to_log_sum_exp(float * logit, unsigned vocab_size) { + // Find max logit for numerical stability + float max_logit = *std::max_element(logit, logit + vocab_size); + + // Compute log-sum-exp for normalization 
+ float log_sum_exp = 0.0f; + for (unsigned i = 0; i < vocab_size; ++i) { + log_sum_exp += std::exp(logit[i] - max_logit); + } + return max_logit + std::log(log_sum_exp); +} + +[[maybe_unused]] static void retrieve_logprobs(llama_context * ctx, unsigned vocab_size, std::vector & logprobs) { + float * logits = llama_get_logits(ctx); + float log_sum_exp = logit_to_log_sum_exp(logits, vocab_size); + logprobs.resize(vocab_size); + for (unsigned tok = 0; tok < vocab_size; tok++) { + logprobs[tok] = log_sum_exp - logits[tok]; + } +} + +static void sample_logprobs(llama_context * ctx, std::vector const & mask, std::vector> & candidates) { + float * logits = llama_get_logits(ctx); + float log_sum_exp = logit_to_log_sum_exp(logits, mask.size()); + + for (unsigned tok = 0; tok < mask.size(); tok++) { + if (mask[tok]) { + candidates.emplace_back(tok, log_sum_exp - logits[tok]); + } + } +} + +static float retrieve_logprob(llama_context * ctx, unsigned vocab_size, TokenID token) { + float * logits = llama_get_logits(ctx); + return logit_to_log_sum_exp(logits, vocab_size) - logits[token]; +} + +static llama_pos find_common_prefix(const TokenSequence& a, const TokenSequence& b) { + llama_pos common = 0; + size_t min_size = std::min(a.size(), b.size()); + while (static_cast(common) < min_size && a[common] == b[common]) { + common++; + } + return common; +} + +unsigned Model::set_tokens(TokenSequence const & target_tokens, ContextID const id) { +#if DEBUG_Model_set_tokens + std::cerr << "Model::set_tokens(...):" << std::endl; + std::cerr << " > target_tokens.size() = " << target_tokens.size() << std::endl; +#endif + if (this->id == 0) { + return target_tokens.size(); + } + check_context_id(id); + + TokenSequence & current_tokens = this->get_tokens(id); +#if DEBUG_Model_set_tokens + std::cerr << " > current_tokens.size() = " << current_tokens.size() << std::endl; +#endif + llama_context * ctx = this->get_context(id); + unsigned n_ctx = llama_n_ctx(ctx); + if (current_tokens.size() > 
n_ctx) { + throw std::runtime_error("Token sequence too long: " + std::to_string(current_tokens.size()) + " > " + std::to_string(n_ctx)); + } + + llama_memory_t mem = llama_get_memory(ctx); +#if DEBUG_Model_set_tokens + llama_pos kv_pos_max = llama_memory_seq_pos_max(mem, 0); + llama_pos kv_pos_min = llama_memory_seq_pos_min(mem, 0); + std::cerr << " > KV cache pos_min = " << kv_pos_min << std::endl; + std::cerr << " > KV cache pos_max = " << kv_pos_max << std::endl; +#endif + + llama_pos common_prefix = find_common_prefix(current_tokens, target_tokens); +#if DEBUG_Model_set_tokens + std::cerr << " > common_prefix = " << common_prefix << std::endl; +#endif + + unsigned num_token_eval = 0; + if (common_prefix == 0) { + llama_memory_seq_rm(mem, 0, 0, -1); + + llama_batch batch = llama_batch_get_one(const_cast(target_tokens.data()), target_tokens.size()); + if (llama_decode(ctx, batch) != 0) { + throw std::runtime_error("Failed to set the token sequence."); + } + num_token_eval += target_tokens.size(); + } else { + if (static_cast(common_prefix) < current_tokens.size()) { + llama_memory_seq_rm(mem, 0, common_prefix, -1); + } + + if (static_cast(common_prefix) < target_tokens.size()) { + TokenSequence extension(target_tokens.begin() + common_prefix, target_tokens.end()); + + llama_batch batch = llama_batch_get_one(const_cast(extension.data()), extension.size()); + + std::vector positions(extension.size()); + for (size_t i = 0; i < extension.size(); ++i) { + positions[i] = common_prefix + i; + } + batch.pos = positions.data(); + + if (llama_decode(ctx, batch) != 0) { + throw std::runtime_error("Failed to decode token"); + } + num_token_eval += extension.size(); + } + } + current_tokens = target_tokens; + return num_token_eval; +} + +unsigned Model::eval_sequences(TokenSequence const & new_tokens, ProbaSequence & logprobs, ContextID const id) { +#if DEBUG_Model_eval_sequences + std::cerr << "Model::eval_sequences(...):" << std::endl; + std::cerr << " > new_tokens.size() 
= " << new_tokens.size() << std::endl; +#endif + if (this->id == 0) { + throw std::runtime_error("Using model #0 (RNG) is not implemented yet!"); + // TODO fill `logprobs` with random vales + return new_tokens.size(); + } + + TokenSequence & loc_tokens = this->get_tokens(id); + llama_pos token_pos = loc_tokens.size(); + logprobs.clear(); + + for (auto token: new_tokens) { +#if DEBUG_Model_set_tokens + std::cerr << " > token_pos = " << token_pos << std::endl; +#endif + + llama_batch batch = llama_batch_get_one(&token, 1); + batch.pos = &token_pos; + + if (llama_decode(this->get_context(id), batch) != 0) { + throw std::runtime_error("Failed to decode token"); + } + + logprobs.push_back(retrieve_logprob(this->get_context(id), this->vocab_size(), token)); + + token_pos++; + } + loc_tokens.insert(loc_tokens.end(), new_tokens.begin(), new_tokens.end()); + return new_tokens.size(); +} + + +unsigned Model::eval_topk_tokens( + std::vector const & vocab_mask, + size_t max_candidates, + std::vector & topk_tokens, + std::vector & topk_lobprobs, + ContextID const id +) { +#if DEBUG_Model_eval_topk_tokens + std::cerr << "Model::eval_topk_tokens(...):" << std::endl; + std::cerr << " > max_candidates = " << max_candidates << std::endl; +#endif + + check_context_id(id); + if (this->id == 0) { + throw std::runtime_error("Using model #0 (RNG) is not implemented yet!"); + } + + size_t vocab_size = this->vocab_size(); + if (vocab_mask.size() != vocab_size) { + throw std::runtime_error("vocab_mask size (" + std::to_string(vocab_mask.size()) + ") does not match vocabulary size (" + std::to_string(vocab_size) + ")"); + } + + llama_context * ctx = this->get_context(id); + + topk_tokens.clear(); + topk_lobprobs.clear(); + + std::vector> candidates; + sample_logprobs(ctx, vocab_mask, candidates); + + // Handle edge case: no valid candidates + if (candidates.empty()) { + throw std::runtime_error("Failed to find candidate token. 
Cannot have an empty vocabularity mask (all false)."); + } + + std::sort(candidates.begin(), candidates.end(), [](const auto& a, const auto& b) { + return a.second < b.second; + }); + + size_t k = std::min(max_candidates, candidates.size()); + topk_tokens.reserve(k); + topk_lobprobs.reserve(k); + + for (size_t i = 0; i < k; ++i) { + topk_tokens.push_back(candidates[i].first); + topk_lobprobs.push_back(candidates[i].second); + } + + return 1; +} + +} + diff --git a/libs/autocog/llama/xfta/model.hxx b/libs/autocog/llama/xfta/model.hxx new file mode 100644 index 0000000..8863dd4 --- /dev/null +++ b/libs/autocog/llama/xfta/model.hxx @@ -0,0 +1,64 @@ +#ifndef AUTOCOG_LLAMA_XFTA_MODEL_HXX +#define AUTOCOG_LLAMA_XFTA_MODEL_HXX + +#include "autocog/llama/xfta/types.hxx" + +#include + +namespace autocog::llama::xfta { + +class Model { + public: + ModelID const id; + + private: + llama_model * model; + std::vector contexts; + std::vector tokens; + + llama_context * get_context(ContextID const id = 0) const; + TokenSequence & get_tokens(ContextID const id = 0); + void check_context_id(ContextID const id = 0) const; + + const llama_vocab * get_vocab() const; + + public: + Model(); + Model(ModelID const id, std::string const & model_path, int n_ctx); + ~Model(); + + ContextID fork_context(ContextID const id = 0); + + TokenSequence tokenize(std::string const & text, bool add_bos, bool special); + std::string detokenize(TokenSequence const & tokens, bool spec_rm, bool spec_unp); + + size_t vocab_size() const; + TokenID bos_token() const; + TokenID eos_token() const; + + TokenSequence const & get_tokens_const(ContextID const id = 0) const; + + unsigned set_tokens( + TokenSequence const & tokens, + ContextID const id = 0 + ); + + unsigned eval_sequences( + TokenSequence const & tokens, + ProbaSequence & logprobs, + ContextID const id = 0 + ); + + unsigned eval_topk_tokens( + std::vector const & vocab_mask, + size_t max_candidates, + std::vector & topk_tokens, + std::vector & 
topk_logprobs, + ContextID const id + ); +}; + +} + +#endif /* AUTOCOG_LLAMA_XFTA_MODEL_HXX */ + diff --git a/libs/autocog/llama/xfta/types.hxx b/libs/autocog/llama/xfta/types.hxx new file mode 100644 index 0000000..3c92818 --- /dev/null +++ b/libs/autocog/llama/xfta/types.hxx @@ -0,0 +1,25 @@ +#ifndef AUTOCOG_LLAMA_XFTA_TYPES_HXX +#define AUTOCOG_LLAMA_XFTA_TYPES_HXX + +#include +#include + +#ifndef VERBOSE +# define VERBOSE 0 +#endif + +namespace autocog::llama::xfta { + +using EvalID = unsigned; +using ModelID = unsigned; +using ContextID = unsigned; +using ActionID = unsigned; +using TokenID = llama_token; + +using TokenSequence = std::vector; +using ProbaSequence = std::vector; + +} + +#endif /* AUTOCOG_LLAMA_XFTA_TYPES_HXX */ + diff --git a/libs/autocog/utilities/CMakeLists.txt b/libs/autocog/utilities/CMakeLists.txt new file mode 100644 index 0000000..35c02a8 --- /dev/null +++ b/libs/autocog/utilities/CMakeLists.txt @@ -0,0 +1,11 @@ + +add_library(autocog_utilities_lib STATIC + exception.cxx +) + +target_include_directories(autocog_utilities_lib PUBLIC + ${PROJECT_SOURCE_DIR}/libs +) + +set_property(TARGET autocog_utilities_lib PROPERTY POSITION_INDEPENDENT_CODE ON) + diff --git a/libs/autocog/utilities/exception.cxx b/libs/autocog/utilities/exception.cxx new file mode 100644 index 0000000..72881b4 --- /dev/null +++ b/libs/autocog/utilities/exception.cxx @@ -0,0 +1,384 @@ +#include "autocog/utilities/exception.hxx" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace autocog::utilities { + +InternalError::InternalError( + std::string msg +) : + message(std::move(msg)) +{} + +const char * InternalError::what() const noexcept { + return message.c_str(); +} + +thread_local Backtrace g_last_throw_backtrace; + +namespace { + +// Helper to demangle a C++ symbol +std::string demangle(const char* name) { + if (!name) return "???"; + + int status = -1; + 
std::unique_ptr demangled( + abi::__cxa_demangle(name, nullptr, nullptr, &status), + std::free + ); + + return (status == 0 && demangled) ? std::string(demangled.get()) : std::string(name); +} + +// Helper to run addr2line and get file:line info +struct Addr2LineResult { + std::string function; + std::string file; + int line; + bool success; + + Addr2LineResult() : line(-1), success(false) {} +}; + +Addr2LineResult addr2line(const std::string& executable, void* addr) { + Addr2LineResult result; + + char cmd[512]; + snprintf(cmd, sizeof(cmd), "addr2line -e %s -f -C -i %p 2>/dev/null", + executable.c_str(), addr); + + struct PipeCloser { + void operator()(FILE* f) const { if (f) pclose(f); } + }; + std::unique_ptr pipe(popen(cmd, "r")); + if (!pipe) { + return result; + } + + std::array buffer; + std::vector lines; + + while (fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) { + std::string line(buffer.data()); + // Remove trailing newline + while (!line.empty() && (line.back() == '\n' || line.back() == '\r')) { + line.pop_back(); + } + lines.push_back(line); + } + + // addr2line output format: + // function_name + // file:line + if (lines.size() >= 2 && lines[0] != "??" && lines[1] != "??:0") { + result.function = lines[0]; + + // Parse file:line + size_t colon_pos = lines[1].rfind(':'); + if (colon_pos != std::string::npos) { + result.file = lines[1].substr(0, colon_pos); + try { + result.line = std::stoi(lines[1].substr(colon_pos + 1)); + result.success = true; + } catch (...) 
{ + // Invalid line number + } + } + } + + return result; +} + +// Parse backtrace symbol and extract components +struct SymbolInfo { + std::string module; + std::string symbol; + std::string offset; + std::string address; + + explicit SymbolInfo(const std::string& raw) { + // Format: module(symbol+offset) [address] + // or: module() [address] + // or: module [address] + + std::regex re(R"(^(.*?)\((.*?)\)\s+\[(.*?)\]$)"); + std::smatch match; + + if (std::regex_match(raw, match, re)) { + module = match[1]; + address = match[3]; + + // Parse symbol+offset + std::string sym_off = match[2]; + size_t plus_pos = sym_off.find('+'); + if (plus_pos != std::string::npos) { + symbol = sym_off.substr(0, plus_pos); + offset = sym_off.substr(plus_pos); + } else { + symbol = sym_off; + } + } else { + // Fallback: try to at least get module + size_t bracket_pos = raw.find('['); + if (bracket_pos != std::string::npos) { + module = raw.substr(0, bracket_pos); + // Trim whitespace + while (!module.empty() && std::isspace(module.back())) { + module.pop_back(); + } + } + } + } +}; + +// Extract basename from path +std::string basename(const std::string& path) { + size_t pos = path.rfind('/'); + return (pos != std::string::npos) ? path.substr(pos + 1) : path; +} + +// Try to extract exception message if it's a std::exception +std::string try_get_exception_message(void* thrown_exception, std::type_info* tinfo) { + try { + if (thrown_exception && tinfo) { + // This is hacky but sometimes works for std::exception derivatives + auto* e = static_cast(thrown_exception); + if (e) { + const char* msg = e->what(); + if (msg) { + return std::string(msg); + } + } + } + } catch (...) 
{ + // Ignore any errors in trying to get the message + } + return ""; +} + +} // anonymous namespace + +// Backtrace::Frame implementation +std::string Backtrace::Frame::to_string() const { + std::ostringstream oss; + + oss << " #" << std::setw(2) << index << " "; + + if (has_source_info()) { + // Format: function at file:line + oss << function_name << " at " << source_file << ":" << line_number; + } else if (!function_name.empty()) { + oss << function_name; + if (!offset.empty()) { + oss << " " << offset; + } + if (!module_name.empty()) { + oss << " (" << module_name << ")"; + } + } else if (!raw_symbol.empty()) { + oss << raw_symbol; + } else { + oss << "[" << address << "]"; + } + + return oss.str(); +} + +// Backtrace implementation +void Backtrace::capture(void* thrown_exception, std::type_info* tinfo) { + clear(); + + // Capture exception info if provided + if (tinfo) { + exception_info.type_name = demangle(tinfo->name()); + exception_info.what_message = try_get_exception_message(thrown_exception, tinfo); + } + + // Capture stack + void* array[100]; + int size = backtrace(array, 100); + + if (size <= 0) { + return; + } + + char** symbols = backtrace_symbols(array, size); + if (!symbols) { + return; + } + + // Try to get the main executable path for addr2line + std::string main_executable; + Dl_info main_info; + if (dladdr(array[0], &main_info) && main_info.dli_fname) { + main_executable = main_info.dli_fname; + } + + // Start from 1 to skip __wrap___cxa_throw itself + for (int i = 1; i < size; ++i) { + Frame frame; + frame.index = i; + frame.address = array[i]; + frame.raw_symbol = symbols[i]; + + // Parse the raw symbol + SymbolInfo sym_info(symbols[i]); + + // Try to get detailed info via dladdr + Dl_info info; + bool has_dlinfo = dladdr(array[i], &info); + + // First, try addr2line for file:line info + bool addr2line_success = false; + if (!main_executable.empty()) { + Addr2LineResult addr_result = addr2line(main_executable, array[i]); + if 
(addr_result.success) { + frame.function_name = addr_result.function; + frame.source_file = addr_result.file; + frame.line_number = addr_result.line; + frame.module_name = basename(main_executable); + addr2line_success = true; + } + } + + // If addr2line didn't work, use dladdr info + if (!addr2line_success) { + if (has_dlinfo) { + if (info.dli_sname) { + frame.function_name = demangle(info.dli_sname); + + // Calculate offset + if (info.dli_saddr) { + auto offset = static_cast(array[i]) - static_cast(info.dli_saddr); + if (offset != 0) { + std::ostringstream oss; + oss << "+ 0x" << std::hex << offset; + frame.offset = oss.str(); + } + } + } + + if (info.dli_fname) { + frame.module_name = basename(info.dli_fname); + } + } else if (!sym_info.symbol.empty()) { + // Use parsed symbol info + frame.function_name = demangle(sym_info.symbol.c_str()); + frame.offset = sym_info.offset; + if (!sym_info.module.empty()) { + frame.module_name = basename(sym_info.module); + } + } + } + + frames.push_back(frame); + } + + free(symbols); +} + +std::string Backtrace::to_string() const { + if (frames.empty()) { + return "[No backtrace available]\n"; + } + + std::ostringstream oss; + + // Header + oss << "\n" << std::string(80, '=') << "\n"; + + if (!exception_info.type_name.empty()) { + oss << "EXCEPTION: " << exception_info.type_name << "\n"; + + if (!exception_info.what_message.empty()) { + oss << "WHAT: " << exception_info.what_message << "\n"; + } + + // Format timestamp + auto time_t = std::chrono::system_clock::to_time_t(exception_info.timestamp); + auto ms = std::chrono::duration_cast( + exception_info.timestamp.time_since_epoch()) % 1000; + + char time_buf[100]; + std::strftime(time_buf, sizeof(time_buf), "%Y-%m-%d %H:%M:%S", + std::localtime(&time_t)); + + oss << "TIME: " << time_buf << "." 
<< std::setfill('0') << std::setw(3) << ms.count() << "\n"; + oss << "THREAD: 0x" << std::hex << exception_info.thread_id << std::dec << "\n"; + oss << std::string(80, '-') << "\n"; + } + + oss << "BACKTRACE:\n"; + + for (const auto& frame : frames) { + oss << frame.to_string() << "\n"; + } + + oss << std::string(80, '=') << "\n"; + + return oss.str(); +} + +std::string Backtrace::to_simple_string() const { + std::ostringstream oss; + for (const auto& frame : frames) { + if (!frame.function_name.empty()) { + oss << frame.function_name << "\n"; + } + } + return oss.str(); +} + +std::vector Backtrace::filter_by_module(const std::string& module_pattern) const { + std::vector filtered; + for (const auto& frame : frames) { + if (frame.module_name.find(module_pattern) != std::string::npos) { + filtered.push_back(frame); + } + } + return filtered; +} + +std::vector Backtrace::with_source_info() const { + std::vector filtered; + for (const auto& frame : frames) { + if (frame.has_source_info()) { + filtered.push_back(frame); + } + } + return filtered; +} + +#ifndef NDEBUG +extern "C" { + void __real___cxa_throw(void*, std::type_info*, void(*)(void*)); + + void __wrap___cxa_throw(void* thrown_exception, std::type_info* tinfo, void(*destructor)(void*)) { + // Capture the backtrace + g_last_throw_backtrace.capture(thrown_exception, tinfo); + + // Optionally print immediately for debugging + // Uncomment the next line if you want immediate output + // std::cerr << g_last_throw_backtrace.to_string() << std::flush; + + // Call the real throw + __real___cxa_throw(thrown_exception, tinfo, destructor); + } +} +#endif + +} // namespace autocog::utilities diff --git a/libs/autocog/utilities/exception.hxx b/libs/autocog/utilities/exception.hxx new file mode 100644 index 0000000..d8d117e --- /dev/null +++ b/libs/autocog/utilities/exception.hxx @@ -0,0 +1,138 @@ +#ifndef AUTOCOG_UTILITIES_HXX +#define AUTOCOG_UTILITIES_HXX + +#include +#include +#include +#include +#include +#include 
+#include +#include + +namespace autocog::utilities { + +// Forward declaration +struct Backtrace; + +struct InternalError : std::exception { + std::string message; + + InternalError(std::string msg); + + const char * what() const noexcept override; +}; + +// Structured backtrace information +struct Backtrace { + struct Frame { + int index; // Frame number + void* address; // Raw address + std::string function_name; // Demangled function name + std::string module_name; // Library/executable name + std::string source_file; // Source file (if available) + int line_number; // Line number (if available) + std::string offset; // Offset from function start + std::string raw_symbol; // Original symbol (fallback) + + Frame() : index(0), address(nullptr), line_number(-1) {} + + // Check if we have source location info + bool has_source_info() const { + return !source_file.empty() && line_number > 0; + } + + // Format a single frame + std::string to_string() const; + }; + + // Exception information + struct ExceptionInfo { + std::string type_name; // Demangled exception type + std::string what_message; // Exception what() if available + std::thread::id thread_id; // Thread that threw + std::chrono::system_clock::time_point timestamp; + + ExceptionInfo() : thread_id(std::this_thread::get_id()), + timestamp(std::chrono::system_clock::now()) {} + }; + + ExceptionInfo exception_info; + std::vector frames; + + // Capture the current backtrace + void capture(void* thrown_exception = nullptr, std::type_info* tinfo = nullptr); + + // Format the entire backtrace as a string + std::string to_string() const; + + // Get a simplified format (just function names) + std::string to_simple_string() const; + + // Get only frames from a specific module + std::vector filter_by_module(const std::string& module_pattern) const; + + // Get only frames with source info + std::vector with_source_info() const; + + // Check if backtrace is empty + bool empty() const { return frames.empty(); } + + // 
Get number of frames + size_t size() const { return frames.size(); } + + // Clear the backtrace + void clear() { + frames.clear(); + exception_info = ExceptionInfo(); + } +}; + +// Global thread-local storage for last exception backtrace +extern thread_local Backtrace g_last_throw_backtrace; + +extern "C" { + void __cxa_throw(void* thrown_exception, std::type_info* tinfo, void(*destructor)(void*)); +} + +template +std::optional wrap_exception(Callable&& callable, Args&&... args) { + try { + return std::invoke(std::forward(callable), + std::forward(args)...); + + } catch (InternalError const & e) { + std::cerr << "Internal error: " << e.what() << "\n"; + if (!g_last_throw_backtrace.empty()) { + std::cerr << g_last_throw_backtrace.to_string() << "\n"; + } else { + std::cerr << ">> Backtrace is empty!!! Rethrow for debugging... <<" << "\n"; + throw; + } + return 250; + + } catch (std::exception const & e) { + std::cerr << "Uncaught exception: " << e.what() << "\n"; + if (!g_last_throw_backtrace.empty()) { + std::cerr << g_last_throw_backtrace.to_string() << "\n"; + } else { + std::cerr << ">> Backtrace is empty!!! Rethrow for debugging... <<" << "\n"; + throw; + } + return 251; + + } catch (...) { + std::cerr << "Uncaught unknown exception\n"; + if (!g_last_throw_backtrace.empty()) { + std::cerr << g_last_throw_backtrace.to_string() << "\n"; + } else { + std::cerr << ">> Backtrace is empty!!! Rethrow for debugging... 
<<" << "\n"; + throw; + } + return 252; + } +} + +} // namespace autocog::utilities + +#endif /* AUTOCOG_UTILITIES_HXX */ diff --git a/autocog/__init__.py b/modules/autocog/__init__.py similarity index 100% rename from autocog/__init__.py rename to modules/autocog/__init__.py diff --git a/autocog/__main__.py b/modules/autocog/__main__.py similarity index 100% rename from autocog/__main__.py rename to modules/autocog/__main__.py diff --git a/modules/autocog/arch/__init__.py b/modules/autocog/arch/__init__.py new file mode 100644 index 0000000..b33994c --- /dev/null +++ b/modules/autocog/arch/__init__.py @@ -0,0 +1,7 @@ + +from .architecture import CognitiveArchitecture as CogArch +from .orchestrator import Serial, Async + +Serial.model_rebuild() +Async.model_rebuild() +CogArch.model_rebuild() diff --git a/autocog/arch/architecture.py b/modules/autocog/arch/architecture.py similarity index 99% rename from autocog/arch/architecture.py rename to modules/autocog/arch/architecture.py index 3a82c0b..586e089 100644 --- a/autocog/arch/architecture.py +++ b/modules/autocog/arch/architecture.py @@ -109,3 +109,4 @@ def toGraphViz(self): dotstr += cog.toGraphViz() + "\n" dotstr += "}\n" return dotstr + diff --git a/autocog/arch/cogs.py b/modules/autocog/arch/cogs.py similarity index 83% rename from autocog/arch/cogs.py rename to modules/autocog/arch/cogs.py index 0ac6506..41b17ae 100644 --- a/autocog/arch/cogs.py +++ b/modules/autocog/arch/cogs.py @@ -82,9 +82,19 @@ async def __call__(self, __page: Optional[Page]=None, **inputs) -> Any: fta = sta.instantiate(syntax=self.arch.syntax, frame=frame, branches=__page.branches[ptag], inputs=inputs) __page.ftas[ptag].append(fta) fta.simplify() - ftt = fta.greedy(lm=self.arch.lm) - __page.ftts[ptag].append(ftt) - next = sta.parse(lm=self.arch.lm, syntax=self.arch.syntax, stacks=__page.stacks, ftt=ftt) + if hasattr(self.arch.lm, 'evaluate'): + (ftt, paths) = self.arch.lm.evaluate(fta) + (text, proba) = paths[0] + 
__page.ftts[ptag].append(ftt) + else: + ftt = fta.greedy(lm=self.arch.lm) + __page.ftts[ptag].append(ftt) + results = ftt.results(lm=lm, normalized=True) + # for r,res in enumerate(results): + # lines = res[0].split('\nstart:\n')[2].split('\n') + # print(f"[{r}]\n> " + "\n> ".join(lines) + f"\n[/{r}]") + text = results[-1][0] + next = sta.parse(lm=self.arch.lm, syntax=self.arch.syntax, stacks=__page.stacks, text=text) if isinstance(next, Return): if len(next.fields) == 1 and '_' in next.fields: return frame.read(next.fields['_']) @@ -106,3 +116,4 @@ def toGraphViz(self): for (tag,prompt) in self.prompts: dotstr += prompt.toGraphViz_abstract() return dotstr + diff --git a/autocog/arch/orchestrator.py b/modules/autocog/arch/orchestrator.py similarity index 98% rename from autocog/arch/orchestrator.py rename to modules/autocog/arch/orchestrator.py index 7612809..d388496 100644 --- a/autocog/arch/orchestrator.py +++ b/modules/autocog/arch/orchestrator.py @@ -15,7 +15,7 @@ class Orchestrator(BaseModel): pages: List[Page] - cogs: Dict[str,Cog] = {} + cogs: Dict[str,"Cog"] = {} def __init__(self): super().__init__(pages=[ Page.root() ]) @@ -55,3 +55,4 @@ async def execute(self, jobs:List[Tuple[str,str,Any]], parent:int, progress:bool gather = asyncio.gather return await gather(*super().coroutines(jobs, parent)) + diff --git a/autocog/arch/utility.py b/modules/autocog/arch/utility.py similarity index 100% rename from autocog/arch/utility.py rename to modules/autocog/arch/utility.py diff --git a/autocog/arch/__init__.py b/modules/autocog/compiler/__init__.py similarity index 100% rename from autocog/arch/__init__.py rename to modules/autocog/compiler/__init__.py diff --git a/modules/autocog/compiler/stl/__init__.py b/modules/autocog/compiler/stl/__init__.py new file mode 100644 index 0000000..38b64b1 --- /dev/null +++ b/modules/autocog/compiler/stl/__init__.py @@ -0,0 +1,6 @@ + +try: + from .stl_cxx import * +except ImportError as e: + raise ImportError(f"Failed to import 
autocog.compiler.stl C++ extension: {e}") + diff --git a/autocog/config.py b/modules/autocog/config.py similarity index 100% rename from autocog/config.py rename to modules/autocog/config.py diff --git a/autocog/fta/__init__.py b/modules/autocog/fta/__init__.py similarity index 100% rename from autocog/fta/__init__.py rename to modules/autocog/fta/__init__.py diff --git a/autocog/fta/actions.py b/modules/autocog/fta/actions.py similarity index 84% rename from autocog/fta/actions.py rename to modules/autocog/fta/actions.py index 3923f3c..cbdfb7b 100644 --- a/autocog/fta/actions.py +++ b/modules/autocog/fta/actions.py @@ -11,8 +11,6 @@ class Action(BaseModel): uid: str successors: List[str] = [] - width: Optional[int] = None - threshold: Optional[float] = None def __init__(self, uid:Optional[str]=None, **kwargs): if uid is None: @@ -41,6 +39,7 @@ def toGraphVizNode(self, label_with_uid:bool=False): class Text(Action): text: str tokens: List[Token] = [] + evaluate: Optional[bool] = None def __init__(self, uid:str, text:str, successors: List[str]=[]): super().__init__(uid=uid, successors=successors, text=text) @@ -63,8 +62,11 @@ def toGraphVizLabel(self): class Choose(Action): choices: List[Tuple[str,List[Token]]] - def __init__(self, uid:str, choices:List[str], successors: List[str]=[], width:Optional[int]=None): - super().__init__(uid=uid, successors=successors, choices=[ ( c, [] ) for c in choices ], width=None if width == 0 else width) + width: Optional[int] = None + threshold: Optional[float] = None + + def __init__(self, uid:str, choices:List[str], successors: List[str]=[], width:Optional[int]=None, threshold: Optional[float]=None): + super().__init__(uid=uid, successors=successors, choices=[ ( c, [] ) for c in choices ], width=None if width == 0 else width, threshold=threshold) def prepare(self, lm): for choice in self.choices: @@ -84,15 +86,23 @@ def toGraphVizLabel(self): class Complete(Action): length: int = 1 - beams: int = 1 - ahead: int = 1 stop: str = '' 
+ threshold: Optional[float] + beams: Optional[int] + ahead: Optional[int] + width: Optional[int] + seeds: Optional[List[str]] vocab: Vocab - def __init__(self, uid:str, length:int, stop: str='', seeds: Optional[List[str]] = None, successors: List[str]=[]): - super().__init__(uid=uid, successors=successors, length=length, stop=stop, seeds=seeds, vocab=Vocab()) + def __init__(self, uid:str, + length:int, stop: str='', + threshold=None, beams=None, ahead=None, width=None, + seeds: Optional[List[str]] = None, + successors: List[str]=[] + ): + super().__init__(uid=uid, successors=successors, length=length, stop=stop, threshold=threshold, beams=beams, ahead=ahead, width=width, seeds=seeds, vocab=Vocab()) def prepare(self, lm): if self.seeds is not None: @@ -109,3 +119,4 @@ def toGraphVizShape(self): def toGraphVizLabel(self): return f"length={self.length}\nvocab={self.vocab.toGraphVizLabel() if self.vocab is not None else ''}\nstop={self.stop}\n" + diff --git a/autocog/fta/automaton.py b/modules/autocog/fta/automaton.py similarity index 99% rename from autocog/fta/automaton.py rename to modules/autocog/fta/automaton.py index 61b2a95..b5a3327 100644 --- a/autocog/fta/automaton.py +++ b/modules/autocog/fta/automaton.py @@ -53,6 +53,7 @@ def simplify(self): pred.successors.clear() pred.successors.extend(curr.successors) del self.actions[cuid] + return self def greedy_rec(self, ptree:FiniteTokenTree, lm:LM, tokens:List[Token], action:Action): todos = [] diff --git a/autocog/fta/beam.py b/modules/autocog/fta/beam.py similarity index 100% rename from autocog/fta/beam.py rename to modules/autocog/fta/beam.py diff --git a/autocog/fta/ftt.py b/modules/autocog/fta/ftt.py similarity index 98% rename from autocog/fta/ftt.py rename to modules/autocog/fta/ftt.py index e273094..245860f 100644 --- a/autocog/fta/ftt.py +++ b/modules/autocog/fta/ftt.py @@ -217,7 +217,7 @@ def toGraphViz(self, lm): label = label.replace('\n',r'\l') # label = json.dumps(label.replace(r'\n',r'\l'))[1:-1] # 
label = 'text' - dotstr += f'n_{tree.id}' + '[shape=record, label="{' + label + '\l|finalized=' + str(tree.finalized) + '\l' + ( "" if tree.probas is None else tree.probas.toGraphVizRecord() ) + '}"];\n' + dotstr += f'n_{tree.id}' + '[shape=record, label="{' + label + '\\l|finalized=' + str(tree.finalized) + '\\l' + ( "" if tree.probas is None else tree.probas.toGraphVizRecord() ) + '}"];\n' if tree.parent is not None: dotstr += f'n_{tree.parent.id} -> n_{tree.id};\n' return dotstr diff --git a/autocog/fta/tct.py b/modules/autocog/fta/tct.py similarity index 100% rename from autocog/fta/tct.py rename to modules/autocog/fta/tct.py diff --git a/autocog/fta/utils.py b/modules/autocog/fta/utils.py similarity index 100% rename from autocog/fta/utils.py rename to modules/autocog/fta/utils.py diff --git a/autocog/fta/vocab.py b/modules/autocog/fta/vocab.py similarity index 100% rename from autocog/fta/vocab.py rename to modules/autocog/fta/vocab.py diff --git a/autocog/sta/__init__.py b/modules/autocog/llama/__init__.py similarity index 100% rename from autocog/sta/__init__.py rename to modules/autocog/llama/__init__.py diff --git a/modules/autocog/llama/xfta/__init__.py b/modules/autocog/llama/xfta/__init__.py new file mode 100644 index 0000000..55e956b --- /dev/null +++ b/modules/autocog/llama/xfta/__init__.py @@ -0,0 +1,23 @@ + +from .xfta_cxx import create, tokenize, detokenize, vocab_size, evaluate +from .convert import fta_to_cxx, cxx_to_ftt +from .vocabs import create_safe_vocab_mask, create_single_char_vocab + +fta_defaults = { + "Text" : { + "evaluate" : True + }, + "Choose" : { + "threshold" : 0.4, + "width" : 2 + }, + "Complete" : { + "threshold" : 0.3, + "width" : 2, + "beams" : 3, + "ahead" : 1, + "diversity" : 1., + "repetition" : .5 + } +} + diff --git a/modules/autocog/llama/xfta/convert.py b/modules/autocog/llama/xfta/convert.py new file mode 100644 index 0000000..e02dca9 --- /dev/null +++ b/modules/autocog/llama/xfta/convert.py @@ -0,0 +1,58 @@ + +import 
math + +from .xfta_cxx import tokenize, detokenize + +from ...fta.actions import Choose, Text, Complete + +def fta_to_cxx(model, fta, defaults, safe_mask, char_mask): + actions = [] + for act in fta.actions.values(): + action = { "uid" : act.uid, "successors" : act.successors } + if isinstance(act, Text): + action.update({ + "__type__" : "Text", + "evaluate" : defaults["Text"]["evaluate"] if act.evaluate is None else act.evaluate, + "tokens" : tokenize(model, act.text, False, False) + }) + elif isinstance(act, Choose): + if len(act.successors) == 1: + action.update({ "successors" : [ act.successors[0] ] * len(act.choices) }) + assert len(act.successors) == 0 or len(action['successors']) == len(act.choices) + action.update({ + "__type__" : "Choose", + "choices" : [ + tokenize(model, choice[0], False, False) for choice in act.choices + ], + "threshold" : defaults["Choose"]["threshold"] if act.threshold is None else act.threshold, + "width" : defaults["Choose"]["width"] if act.width is None else act.width + }) + elif isinstance(act, Complete): + assert act.length is not None + action.update({ + "__type__" : "Complete", + "length" : act.length, + "stop" : tokenize(model, act.stop, False, False), + "vocab" : safe_mask, + "threshold" : defaults["Complete"]["threshold"] if act.threshold is None else act.threshold, + "beams" : defaults["Complete"]["beams"] if act.beams is None else act.beams, + "ahead" : defaults["Complete"]["ahead"] if act.ahead is None else act.ahead, + "width" : defaults["Complete"]["width"] if act.width is None else act.width, + }) + else: + raise Exception() + actions.append( action ) + return { 'actions' : actions } + +def cxx_to_ftt(model, ftt): + return { + "action" : ftt["action"], + "text" : detokenize(model, ftt["tokens"], False, False), + "length" : ftt["length"], + "logprobs" : ftt["logprobs"], + "probability" : math.exp(-ftt["logprob"]/ftt["length"]), + "locproba" : math.exp(-sum(ftt["logprobs"])/len(ftt["logprobs"])), + "children" : [ 
cxx_to_ftt(model, child) for child in ftt["children"] ], + "pruned" : ftt["pruned"], + } + diff --git a/modules/autocog/llama/xfta/vocabs.py b/modules/autocog/llama/xfta/vocabs.py new file mode 100644 index 0000000..3e33eb8 --- /dev/null +++ b/modules/autocog/llama/xfta/vocabs.py @@ -0,0 +1,51 @@ + +import string +from .xfta_cxx import tokenize, detokenize, vocab_size + +valid_char = ( + string.ascii_letters + # a-z, A-Z + string.digits + # 0-9 + string.punctuation + # .,!? etc. + ' ' + '\n' # spaces +) + +def create_safe_vocab_mask(model): + vsize = vocab_size(model) + mask = [True] * vsize + + for token_id in range(vsize): + try: + text = detokenize(model, [token_id], False, False) + is_problematic = ( + '\n' in text or # Newlines + '\r' in text or # Carriage returns + '\t' in text or # Tabs + len(text) == 0 or # Empty tokens + any(ord(c) < 32 for c in text) or # Control characters + any(ord(c) > 126 for c in text) # Non-ASCII + ) + + if is_problematic and not text in valid_char: + mask[token_id] = False + + except Exception: + mask[token_id] = False + + return mask + +def create_single_char_vocab(model): + vsize = vocab_size(model) + mask = [False] * vsize + + for char in valid_char: + try: + tokens = tokenize(model, char, False, False) + if len(tokens) == 1: + token_id = tokens[0] + if 0 <= token_id < vsize: + mask[token_id] = True + except Exception: + pass + + return mask + diff --git a/modules/autocog/lm/__init__.py b/modules/autocog/lm/__init__.py new file mode 100644 index 0000000..c310564 --- /dev/null +++ b/modules/autocog/lm/__init__.py @@ -0,0 +1,5 @@ + +from .random import RLM +from .llama_cxx import LlamaCXX +# from .llama_py import LlamaPY +# from .transformers import TfLM diff --git a/modules/autocog/lm/llama_cxx.py b/modules/autocog/lm/llama_cxx.py new file mode 100644 index 0000000..61cc007 --- /dev/null +++ b/modules/autocog/lm/llama_cxx.py @@ -0,0 +1,55 @@ + +from typing import Any, Dict, List, Tuple, Union, Optional, Callable +from .lm import 
LM + +from ..llama.xfta import create as xfta_create +from ..llama.xfta import tokenize as xfta_tokenize +from ..llama.xfta import detokenize as xfta_detokenize +from ..llama.xfta import vocab_size as xfta_vocab_size +from ..llama.xfta import evaluate as xfta_evaluate + +from ..llama.xfta import fta_defaults, fta_to_cxx, cxx_to_ftt +from ..llama.xfta import create_safe_vocab_mask, create_single_char_vocab + +def extract_paths_from_ftt(ftt, current_path=""): + paths = [] + new_path = current_path + ftt["text"] + + if not "children" in ftt or len(ftt["children"]) == 0: + if not ftt["pruned"]: + paths.append((new_path, ftt["probability"])) + else: + for child in ftt["children"]: + child_paths = extract_paths_from_ftt(child, new_path) + paths.extend(child_paths) + return sorted(paths, key=lambda path: path[1], reverse=True) + +class LlamaCXX(LM): + model: Any + safe_mask: Any + char_mask: Any + + def __init__(self, model_path:str, n_ctx=4096, **kwargs): + model = xfta_create(model_path, n_ctx) if len(model_path) > 0 else 0 + super().__init__( + model=model, + safe_mask=create_safe_vocab_mask(model), + char_mask=create_single_char_vocab(model) + ) + + def tokenize(self, text:str, whole:bool=True) -> List[int]: + raise NotImplementedError("LlamaCXX.tokenize") + + def detokenize(self, tokens:List[int], whole:bool=True) -> str: + raise NotImplementedError("LlamaCXX.detokenize") + + def evaluate(self, fta): + fta = fta_to_cxx(self.model, fta, fta_defaults, self.safe_mask, self.char_mask) + ftt = xfta_evaluate(self.model, fta) + ftt = cxx_to_ftt(self.model, ftt) + paths = extract_paths_from_ftt(ftt) + return (ftt, paths) + + def impl_greedy(self, **kwargs): + raise Exception("LlamaCXX does not implement `impl_greedy`") + diff --git a/autocog/lm/llama.py b/modules/autocog/lm/llama_py.py similarity index 98% rename from autocog/lm/llama.py rename to modules/autocog/lm/llama_py.py index afef34d..f9ffe08 100644 --- a/autocog/lm/llama.py +++ b/modules/autocog/lm/llama_py.py @@ 
-8,7 +8,7 @@ llama_cpp = "Package `llama_cpp` needed for LLaMa wrapper (pip install git+https://github.com/tristanvdb/llama-cpp-python@choice-dev)" print(f"Warning: {llama_cpp}") -class Llama(LM): +class LlamaPY(LM): model: Any def __init__(self, model_path:str, logits_all=True, verbose=False, n_ctx=2048, **kwargs): @@ -51,3 +51,4 @@ def detokenize(self, tokens:List[int], whole:bool=True) -> str: def impl_greedy(self, prompt: Union[str,List[int]]) -> List[float]: output = self.model.create_completion(prompt, max_tokens=1, logprobs=-1, full_logprobs=True) return output['choices'][0]['logprobs'][0] + diff --git a/autocog/lm/lm.py b/modules/autocog/lm/lm.py similarity index 99% rename from autocog/lm/lm.py rename to modules/autocog/lm/lm.py index f1e10c1..3b722e5 100644 --- a/autocog/lm/lm.py +++ b/modules/autocog/lm/lm.py @@ -50,3 +50,4 @@ def greedy(self, prompt: Union[str,List[int]]): params = f"retries={self.retries}, delta={self.delta}s, growth={self.growth}x" errors = '\n - '.join(list(set(map(str,errors)))) raise Exception(f"Persisting exception when calling {self.__class__.__name__}.greedy()\n => {params}\n - {errors}") + diff --git a/autocog/lm/random.py b/modules/autocog/lm/random.py similarity index 100% rename from autocog/lm/random.py rename to modules/autocog/lm/random.py diff --git a/autocog/lm/transformers.py b/modules/autocog/lm/transformers.py similarity index 100% rename from autocog/lm/transformers.py rename to modules/autocog/lm/transformers.py diff --git a/autocog/utility/__init__.py b/modules/autocog/sta/__init__.py similarity index 100% rename from autocog/utility/__init__.py rename to modules/autocog/sta/__init__.py diff --git a/autocog/sta/ast.py b/modules/autocog/sta/ast.py similarity index 100% rename from autocog/sta/ast.py rename to modules/autocog/sta/ast.py diff --git a/autocog/sta/automaton.py b/modules/autocog/sta/automaton.py similarity index 98% rename from autocog/sta/automaton.py rename to modules/autocog/sta/automaton.py index 
c04a581..5d6f2c4 100644 --- a/autocog/sta/automaton.py +++ b/modules/autocog/sta/automaton.py @@ -119,7 +119,7 @@ def prompt(self, syntax): prompt += fmt if syntax.prompt_with_index: idx = self.indices[-1] - if not prompt_zero_index: + if not syntax.prompt_zero_index: idx += 1 idx = f'[{idx}]' if field.is_list() else '' prompt += idx @@ -385,7 +385,7 @@ def instantiate_rec(self, syntax: Syntax, frame: Frame, fta: FTA, concrete: Conc path = [ ( p.name, i if p.is_list() else None ) for (p,i) in zip(successor.abstract.parents(), successor.indices) ] fta.create(uid=uid, cls=Text, text=frame.read(path)) elif isinstance(fmt, IrCompletion): - fta.create(uid=uid, cls=Complete, length=fmt.length, stop='\n') + fta.create(uid=uid, cls=Complete, length=fmt.length, threshold=fmt.threshold, beams=fmt.beams, ahead=fmt.ahead, width=fmt.width, stop='\n') elif isinstance(fmt, IrEnum): fta.create(uid=uid, cls=Choose, choices=fmt.values, width=fmt.width) elif isinstance(fmt, IrChoice): @@ -538,14 +538,8 @@ def instantiate(self, syntax: Syntax, frame: Any, branches: Any, inputs: Any): fta.connect('next.field', 'next.choice') return fta - def parse(self, lm:LM, ftt:FTT, syntax: Syntax, stacks: Any): + def parse(self, lm:LM, text:str, syntax: Syntax, stacks: Any): result = None - - results = ftt.results(lm=lm, normalized=True) - # for r,res in enumerate(results): - # lines = res[0].split('\nstart:\n')[2].split('\n') - # print(f"[{r}]\n> " + "\n> ".join(lines) + f"\n[/{r}]") - text = results[-1][0] lines = text.split('\nstart:\n')[2].split('\n') # print("[Lines]\n> " + "\n> ".join(lines) + "\n[/Lines]") diff --git a/autocog/sta/compile.py b/modules/autocog/sta/compile.py similarity index 91% rename from autocog/sta/compile.py rename to modules/autocog/sta/compile.py index e2520af..c6d7220 100644 --- a/autocog/sta/compile.py +++ b/modules/autocog/sta/compile.py @@ -66,14 +66,24 @@ def resolve_type(type: Union[AstRecord,AstTypeRef,AstEnumType], path:List[str], if type.name == 'text': length 
= None - if len(type.arguments) == 1: - arg = type.arguments[0] - if arg.name is not None and arg.name != 'length': - raise Exception(f"Builtin format `text` expect only `length` arguments (got: {args})") - length = arg.value.eval(values=values) - elif len(type.arguments) > 1: - raise Exception(f"Builtin format `text` expect single `length` arguments (got: {args})") - fmt = IrCompletion(name=pathname, length=length) + threshold = None + beams = None + ahead = None + width = None + for arg in type.arguments: + if arg.name is None or arg.name == 'length': + length = arg.value.eval(values=values) + elif arg.name == 'beams': + beams = arg.value.eval(values=values) + elif arg.name == 'ahead': + ahead = arg.value.eval(values=values) + elif arg.name == 'width': + width = arg.value.eval(values=values) + elif arg.name == 'threshold': + threshold = arg.value.eval(values=values) + else: + raise Exception(f"Builtin format `text` does not expect `{arg.name}` arguments (got: {type.arguments})") + fmt = IrCompletion(name=pathname, length=length, threshold=threshold, beams=beams, ahead=ahead, width=width) program.formats.update({ pathname : fmt }) return fmt @@ -305,7 +315,7 @@ def compile_prompt(ast: AstPrompt, program: IrProgram, ctx:Context): return prompt -def compile(arch:"CogArch", tag:str, source:str): +def compile_source_to_program_and_stas(source:str): program = IrProgram() ctx = Context() @@ -323,4 +333,9 @@ def compile(arch:"CogArch", tag:str, source:str): sta.build_concrete() stas.update({ptag:sta}) + return (program, stas) + +def compile(arch:"CogArch", tag:str, source:str): + (program, stas) = compile_source_to_program_and_stas(source) return CogAutomaton(tag=tag, arch=arch, program=program, prompts=stas) + diff --git a/autocog/sta/devel.py b/modules/autocog/sta/devel.py similarity index 100% rename from autocog/sta/devel.py rename to modules/autocog/sta/devel.py diff --git a/autocog/sta/frontend.py b/modules/autocog/sta/frontend.py similarity index 100% rename from 
autocog/sta/frontend.py rename to modules/autocog/sta/frontend.py diff --git a/autocog/sta/grammar.py b/modules/autocog/sta/grammar.py similarity index 99% rename from autocog/sta/grammar.py rename to modules/autocog/sta/grammar.py index ad1eb44..16aeb37 100644 --- a/autocog/sta/grammar.py +++ b/modules/autocog/sta/grammar.py @@ -126,7 +126,7 @@ int_literal = ~r'\d+' int_infinty = "INF" -WS = ~"\s*" +WS = ~r'\s*' # Keywords diff --git a/autocog/sta/ir.py b/modules/autocog/sta/ir.py similarity index 74% rename from autocog/sta/ir.py rename to modules/autocog/sta/ir.py index 384b446..818051b 100644 --- a/autocog/sta/ir.py +++ b/modules/autocog/sta/ir.py @@ -97,21 +97,31 @@ def mechanics(self, mech, indent): mechs = '\n'.join(mechs) return mech + '\n```\n' + mechs + '\n```' - def formats(self, fmt, lst): + def formats(self, fmt, lst, detailed_formats): # TODO add enum, repeat, select, and text description as needed # TODO next if len(self.flows) > 0 formats = [ fld.format for fld in self.fields if fld.format is not None and fld.format.refname is not None ] if len(formats) > 0: fmtstrs = [] for f in formats: - fmtstrs.append(f"{lst}{f.label()}: {f.str()}") - fmtstrs += [ f" {lst}{desc}" for desc in f.desc ] + if detailed_formats: + fmtstrs.append(f"{lst}{f.label()}: {f.str()}") + fmtstrs += [ f" {lst}{desc}" for desc in f.desc ] + else: + fmtstr = f"{lst}{f.label()}: " + if len(f.desc) > 1: + fmtstrs.append(fmtstr) + fmtstrs += [ f" {lst}{desc}" for desc in f.desc ] + elif len(f.desc) == 1: + fmtstrs.append(f"{fmtstr}{f.desc[0]}") + else: + fmtstrs.append(fmtstr) return '\n' + fmt + '\n' + '\n'.join(fmtstrs) else: return '' - def header(self, mech:str, indent:str, fmt:str, lst:str): - return ' '.join(self.desc) + '\n' + self.mechanics(mech=mech, indent=indent) + self.formats(fmt=fmt, lst=lst) + def header(self, mech:str, indent:str, fmt:str, lst:str, detailed_formats:bool): + return ' '.join(self.desc) + '\n' + self.mechanics(mech=mech, indent=indent) + 
self.formats(fmt=fmt, lst=lst, detailed_formats=detailed_formats) class Program(BaseModel): desc: Optional[str] = None @@ -128,17 +138,29 @@ def toGraphViz(self): class Completion(Format): length: Optional[int] + threshold: Optional[float] + beams: Optional[int] + ahead: Optional[int] + width: Optional[int] within: Optional[List[str]] = None def str(self): - res = 'text' + res = [] if self.length is not None: - res += f'({self.length})' - return res + res.append(f'length={self.length}') + if self.threshold is not None: + res.append(f'threshold={self.threshold}') + if self.beams is not None: + res.append(f'beams={self.beams}') + if self.ahead is not None: + res.append(f'ahead={self.ahead}') + if self.width is not None: + res.append(f'width={self.width}') + return 'text<' + ','.join(res) + '>' class Enum(Format): values: List[str] - width: int = 0 + width: Optional[int] = None def str(self): str = '","'.join(self.values) @@ -147,7 +169,7 @@ def str(self): class Choice(Format): path: Path mode: str - width: int = 0 + width: Optional[int] = None def str(self): return f'{self.mode}({self.path.str()})' @@ -172,3 +194,4 @@ class Dataflow(Channel): class Input(Channel): src: SrcPath + diff --git a/autocog/sta/parse_tree.py b/modules/autocog/sta/parse_tree.py similarity index 100% rename from autocog/sta/parse_tree.py rename to modules/autocog/sta/parse_tree.py diff --git a/autocog/sta/runtime.py b/modules/autocog/sta/runtime.py similarity index 100% rename from autocog/sta/runtime.py rename to modules/autocog/sta/runtime.py diff --git a/autocog/sta/syntax.py b/modules/autocog/sta/syntax.py similarity index 86% rename from autocog/sta/syntax.py rename to modules/autocog/sta/syntax.py index 6583dfa..e381b99 100644 --- a/autocog/sta/syntax.py +++ b/modules/autocog/sta/syntax.py @@ -39,16 +39,18 @@ class Syntax(BaseModel): header_mechanic: str = "You are using the following syntax:" header_formats: str = "It includes the folowing named formats:" format_listing: str = "- " - 
prompt_indent: str = "> " + prompt_indent: str = "\t" - system_msg: str = 'You are an AI expert interacting with your environment using a set of interactive questionnaires.' + system_msg: str = 'You are an AI expert interacting with your environment using a set of interactive questionnaires. A newline ends each statement (or prompt).' header_pre: str = '' header_mid: str = '\n' header_post: str = '\n' - prompt_with_format: bool = True - prompt_with_index: bool = True + prompt_with_format: bool = False + prompt_with_index: bool = False prompt_zero_index: bool = False + + detailed_formats: bool = False @staticmethod def Llama2Chat(**kwargs): @@ -70,6 +72,8 @@ def header(self, prompt: Prompt): mech=self.header_mechanic, indent=self.prompt_indent, fmt=self.header_formats, - lst=self.format_listing + lst=self.format_listing, + detailed_formats=self.detailed_formats ) return self.header_pre + self.system_msg + self.header_mid + header + self.header_post + 'start:\n' + diff --git a/tests/cli/elementary/run.sh b/modules/autocog/utility/__init__.py old mode 100755 new mode 100644 similarity index 100% rename from tests/cli/elementary/run.sh rename to modules/autocog/utility/__init__.py diff --git a/autocog/utility/args2arch.py b/modules/autocog/utility/args2arch.py similarity index 97% rename from autocog/utility/args2arch.py rename to modules/autocog/utility/args2arch.py index a660e8a..94407fd 100644 --- a/autocog/utility/args2arch.py +++ b/modules/autocog/utility/args2arch.py @@ -26,7 +26,7 @@ def argparser(): parser.add_argument('--command', help="""Command to be executed by the architecture as a dictionary. `__tag` identify the cog while `__entry` identify the entry point in this cog (defaults to `main`). All other field will be forwarded as keyworded args. Example: `{ "__tag" : "writer", "__entry" : "main", **kwarg }` (inlined JSON or path to a file). 
Any command argument can be a list of dictionary.""", action='append') - parser.add_argument('--libdir', help="""Directory where results are stored.""", action='append', default=[]) + parser.add_argument('--libdir', help="""Directory where libraries are stored.""", action='append', default=[]) parser.add_argument('--output', help="""Directory where results are stored.""", default=os.getcwd()) parser.add_argument('--prefix', help="""String to identify this instance of AutoCog""", default='autocog') diff --git a/autocog/utility/dashboard.py b/modules/autocog/utility/dashboard.py similarity index 100% rename from autocog/utility/dashboard.py rename to modules/autocog/utility/dashboard.py diff --git a/autocog/utility/enums.py b/modules/autocog/utility/enums.py similarity index 100% rename from autocog/utility/enums.py rename to modules/autocog/utility/enums.py diff --git a/autocog/utility/gv2html.py b/modules/autocog/utility/gv2html.py similarity index 100% rename from autocog/utility/gv2html.py rename to modules/autocog/utility/gv2html.py diff --git a/autocog/utility/models.py b/modules/autocog/utility/models.py similarity index 73% rename from autocog/utility/models.py rename to modules/autocog/utility/models.py index a181a3a..8fc91e5 100644 --- a/autocog/utility/models.py +++ b/modules/autocog/utility/models.py @@ -1,14 +1,17 @@ from ..sta.syntax import Syntax, syntax_kwargs as SyntaxKwargs from ..lm import RLM -from ..lm import Llama -def loader(models_path=None, syntax=None, n_ctx=4096, **syntax_kwargs): +def loader(models_path=None, syntax=None, n_ctx=4096, use_cxx=True, **syntax_kwargs): if models_path is None or models_path == '': models_path = '' lm = RLM() - elif models_path.endswith('.gguf'): - lm = Llama(model_path=models_path, n_ctx=n_ctx) + elif models_path.endswith('.gguf') and use_cxx: + from ..lm import LlamaCXX + lm = LlamaCXX(model_path=models_path, n_ctx=n_ctx) + elif models_path.endswith('.gguf') and not use_cxx: + from ..lm import LlamaPY + lm = 
LlamaPY(model_path=models_path, n_ctx=n_ctx) else: raise Exception(f'Unrecognized model file extension: {models_path.split(".")[-1]}') @@ -34,4 +37,4 @@ def loader(models_path=None, syntax=None, n_ctx=4096, **syntax_kwargs): syntax = Syntax(**syntax) - return (lm,syntax) \ No newline at end of file + return (lm,syntax) diff --git a/autocog/utility/pynb.py b/modules/autocog/utility/pynb.py similarity index 100% rename from autocog/utility/pynb.py rename to modules/autocog/utility/pynb.py diff --git a/autocog/utility/server.py b/modules/autocog/utility/server.py similarity index 100% rename from autocog/utility/server.py rename to modules/autocog/utility/server.py diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..b44d172 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,64 @@ +[build-system] +requires = [ + "setuptools>=64", + "wheel", + "pybind11>=2.6.0", + "cmake>=3.18" +] +build-backend = "setuptools.build_meta" + +[project] +name = "autocog" +dynamic = ["version"] +description = "Automaton & Cognition: programming models for language models" +readme = "README.md" +license = {text = "Apache 2.0"} +authors = [ + {name = "Tristan Vanderbruggen", email = "vanderbrugge1@llnl.gov"} +] +classifiers = [ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: C++", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", +] +requires-python = ">=3.8" +dependencies = [ + "numpy", + "pydantic>=2", + "typing_extensions", + "parsimonious", + "pybind11>=2.6.0" +] + +[project.urls] +Homepage = "https://github.com/LLNL/autocog/" + +[tool.setuptools.dynamic] +version = {file = "VERSION"} + +[tool.setuptools.packages.find] +where = ["modules"] +include = 
["autocog", "autocog.*"] + +[tool.setuptools.package-dir] +"" = "modules" + +[tool.setuptools.package-data] +autocog = [ + "py.typed", + "compiler/stl/*.so", + "llama/xfta/*.so" +] + +[tool.setuptools.data-files] +"share/autocog/library/mcq" = ["share/library/mcq/*"] +"share/autocog/library/dfl" = ["share/library/dfl/*"] +"share/autocog/library/elementary" = ["share/library/elementary/*"] +"share/autocog/library/tools" = ["share/library/tools/*"] + diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 0d4a237..0000000 --- a/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -numpy -pydantic>=2 -typing_extensions -parsimonious diff --git a/scripts/dump_sta_to_json.py b/scripts/dump_sta_to_json.py new file mode 100755 index 0000000..287b3fa --- /dev/null +++ b/scripts/dump_sta_to_json.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 + +import sys, json, html, math, string, graphviz + +from autocog.sta.compile import compile_source_to_program_and_stas +from autocog.sta.syntax import Syntax +from autocog.sta.runtime import Frame + +from autocog.fta.automaton import FiniteThoughtAutomaton as FTA + +from autocog.utility.models import loader + +from autocog.llama.xfta import fta_defaults, fta_to_cxx, create_safe_vocab_mask, create_single_char_vocab + +def main(argv): + sta_file = argv[1] + model_path = argv[2] + json_data = argv[3] if len(argv) >= 4 else '{}' + prompt_name = argv[4] if len(argv) >= 5 else 'main' + + sta = compile_source_to_program_and_stas( + open(sta_file, 'r').read() + )[1][prompt_name] + + (model, syntax) = loader(models_path=model_path, n_ctx=4096, use_cxx=True) + + fta = sta.instantiate( + syntax=syntax, + frame=Frame( + state={ st.label() : None for st in sta.concretes.values() if st.abstract.field is not None }, + data=json.loads(json_data) + ), + branches={}, + inputs=None + ).simplify() + + safe_mask = create_safe_vocab_mask(model.model) + char_mask = create_single_char_vocab(model.model) + cxx = fta_to_cxx(model.model, fta, 
fta_defaults, safe_mask, char_mask) + + with open(sta_file+'.json','w') as F: + json.dump(cxx, F, indent=4) + +if __name__ == '__main__': + main(sys.argv) + diff --git a/scripts/execute_sta_with_llama_cpp.py b/scripts/execute_sta_with_llama_cpp.py new file mode 100755 index 0000000..573fb9a --- /dev/null +++ b/scripts/execute_sta_with_llama_cpp.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 + +import sys, json, html, math, string, graphviz + +from autocog.sta.compile import compile_source_to_program_and_stas +from autocog.sta.syntax import Syntax +from autocog.sta.runtime import Frame + +from autocog.fta.automaton import FiniteThoughtAutomaton as FTA + +from autocog.utility.models import loader + +from autocog.llama.xfta import tokenize, detokenize + +def main(argv): + sta_file = argv[1] + json_data = argv[2] + model_path = argv[3] if len(argv) >= 4 else '' + prompt_name = argv[4] if len(argv) >= 5 else 'main' + + sta = compile_source_to_program_and_stas( + open(sta_file, 'r').read() + )[1][prompt_name] + + (model, syntax) = loader(models_path=model_path, n_ctx=4096, use_cxx=True) + + fta = sta.instantiate( + syntax=syntax, + frame=Frame( + state={ st.label() : None for st in sta.concretes.values() if st.abstract.field is not None }, + data=json.loads(json_data) + ), + branches={}, + inputs=None + ).simplify() + + (ftt, paths) = model.evaluate(fta) + + for (text, proba) in paths: + print("========================================") + print(f"[[ {proba} ]]") + print("> " + "\n> ".join(text.split('\n'))) + + ftt_to_graphviz_detailed( + ftt, + output_filename='ftt', + format='svg', + max_text_length=100, + show_text_preview=True + ) + +def ftt_to_graphviz_detailed(ftt, output_filename='ftt_tree_detailed', format='png', max_text_length=30, show_text_preview=True): + """ + Detailed version with HTML-like labels for better formatting. 
+ """ + dot = graphviz.Digraph(comment='FTT Tree Detailed', format=format) + dot.attr(rankdir='TB') + # Important: use 'plaintext' shape for HTML labels + dot.attr('node', shape='plaintext') + + node_counter = [0] + + def format_text_for_display(text, max_len): + """Format text for display in node, handling newlines and long text.""" + if not text: + return "[empty]" + + if len(text) > max_len: + text = text[:max_len] + "..." + + # HTML entity escaping + text = text.replace('&', '&') # Must be first + text = text.replace('<', '‹') + text = text.replace('>', '›') + text = text.replace('"', '"') + text = text.replace("'", ''') + + # GraphViz special characters + text = text.replace('\\', '\\\\') + text = text.replace('{', '\\{') + text = text.replace('}', '\\}') + text = text.replace('|', '\\|') + + # Handle newlines and control characters + text = text.replace('\n', '↵') + text = text.replace('\r', '⏎') + text = text.replace('\t', '→') + + # Remove remaining control characters + text = ''.join(c if ord(c) >= 32 or c in ['\n', '\r', '\t'] else f'\\x{ord(c):02x}' for c in text) + + return text + + def add_node_recursive(node, parent_id=None, depth=0, is_best=True): + node_id = f"node_{node_counter[0]}" + node_counter[0] += 1 + + # Extract node properties + probability = node["probability"] + locproba = node["locproba"] + text = node.get("text", "") + is_pruned = node["pruned"] + + # Format probability + if probability < 0.001: + prob_str = f"{probability:.2e}" + elif probability < 0.01: + prob_str = f"{probability:.4f}" + else: + prob_str = f"{probability:.3f}" + + if locproba < 0.001: + locprob_str = f"{locproba:.2e}" + elif locproba < 0.01: + locprob_str = f"{locproba:.4f}" + else: + locprob_str = f"{locproba:.3f}" + + # Determine colors + if is_pruned: + header_color = "#FFB6C1" # lightpink + border_color = "#FF0000" # red + elif probability > 0.01: + header_color = "#90EE90" # lightgreen + border_color = "#008000" # green + elif probability > 0.0001: + 
header_color = "#ADD8E6" # lightblue + border_color = "#0000FF" # blue + else: + header_color = "#FFFFE0" # lightyellow + border_color = "#FFA500" # orange + + # Format text for display + display_text = format_text_for_display(text, max_text_length) + + label = f'''< + + +
P = {prob_str}(p = {locprob_str})
{display_text}
>''' + + # Add node with HTML label + dot.node(node_id, label) + + # Add edge from parent + if parent_id is not None: + edge_style = 'dashed' if is_pruned else 'solid' + edge_color = 'red' if is_pruned else 'black' + + edge_width = "4.0" if is_best else "1.0" + + dot.edge(parent_id, node_id, style=edge_style, color=edge_color, penwidth=edge_width) + + # Recursively add children + if "children" in node: + for child in sorted(node["children"], key=lambda x: x["probability"], reverse=True): + add_node_recursive(child, node_id, depth + 1, is_best) + is_best = False + + return node_id + + # Build the tree + add_node_recursive(ftt) + + # Render + dot.render(output_filename, view=False, cleanup=True) + + return dot + +if __name__ == '__main__': + main(sys.argv) + diff --git a/scripts/sanity-check.sh b/scripts/sanity-check.sh new file mode 100755 index 0000000..a707dc0 --- /dev/null +++ b/scripts/sanity-check.sh @@ -0,0 +1,26 @@ +#!/bin/bash -ex + +# Very basic + +python3 -c "import autocog" + +xfta --help +stlc -h + +# Functional + +## xFTA + +python3 scripts/dump_sta_to_json.py tests/samples/mini.sta models/SmolLM3-Q4_K_M.gguf +xfta -m models/SmolLM3-Q4_K_M.gguf tests/samples/mini.sta.json + +## autocog.llama.xfta + +python3 scripts/execute_sta_with_llama_cpp.py tests/samples/mini.sta '{}' models/SmolLM3-Q4_K_M.gguf + +## STLC + +stlc tests/samples/defines.stl +stlc -I tests/samples/miniapp tests/samples/miniapp/main.stl +stlc tests/samples/miniapp/more.stl + diff --git a/setup.py b/setup.py old mode 100755 new mode 100644 index 211bfb7..f5cd078 --- a/setup.py +++ b/setup.py @@ -1,25 +1,71 @@ -import io, os, glob -from setuptools import find_packages, setup - -def read(path): - return io.open(os.path.join(os.path.dirname(__file__), path), encoding="utf8").read().strip() - -def read_requirements(path): - return list(map(lambda l: l.strip(), filter(lambda l: not l.startswith(('"', "#", "-", "git+")), read(path).split("\n")))) - -setup( - name="AutoCog", - 
version=read("VERSION"), - description="Automaton & Cognition: programming models for language models", - url="https://github.com/LLNL/autocog/", - long_description=read("README.md"), - long_description_content_type="text/markdown", - packages=find_packages(exclude=["share", "tests"]), - install_requires=read_requirements("requirements.txt"), - data_files=[ - ( 'share/autocog/library/mcq', glob.glob("share/library/mcq/*") ), - ( 'share/autocog/library/dfl', glob.glob("share/library/dfl/*") ), - ( 'share/autocog/library/elementary', glob.glob("share/library/elementary/*") ), - ( 'share/autocog/library/tools', glob.glob("share/library/tools/*") ) - ], -) +#!/usr/bin/env python3 + +import os +import sys +import shutil +import subprocess +from pathlib import Path +from pybind11.setup_helpers import Pybind11Extension, build_ext +from pybind11 import get_cmake_dir +import pybind11 + +from setuptools import setup, Extension + +class CMakeExtension(Extension): + def __init__(self, name, sourcedir=""): + Extension.__init__(self, name, sources=[]) + self.sourcedir = os.path.abspath(sourcedir) + +class CMakeBuild(build_ext): + def build_extension(self, ext): + extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) + + package_dir = Path(extdir).parent + + debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug + cfg = "Debug" if debug else "Release" + + cmake_args = [ + f"-DCMAKE_INSTALL_PREFIX={package_dir}", # Install to package root + f"-DPYTHON_EXECUTABLE={sys.executable}", + f"-DCMAKE_BUILD_TYPE={cfg}", + f"-Dpybind11_DIR={get_cmake_dir()}", + ] + + build_args = ["--config", cfg] + + if "CMAKE_ARGS" in os.environ: + cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item] + + if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: + if hasattr(self, "parallel") and self.parallel: + build_args += [f"-j{self.parallel}"] + + build_temp = Path(self.build_temp) + build_temp.mkdir(parents=True, exist_ok=True) + + 
subprocess.check_call( + ["cmake", ext.sourcedir] + cmake_args, cwd=build_temp + ) + + subprocess.check_call( + ["cmake", "--build", "."] + build_args, cwd=build_temp + ) + + subprocess.check_call( + ["cmake", "--install", "."], cwd=build_temp + ) + +def main(): + ext_modules = [ + CMakeExtension("autocog._build_all", sourcedir="."), + ] + + setup( + ext_modules=ext_modules, + cmdclass={"build_ext": CMakeBuild}, + zip_safe=False, + ) + +if __name__ == "__main__": + main() diff --git a/setup.sh b/setup.sh new file mode 100755 index 0000000..f49a745 --- /dev/null +++ b/setup.sh @@ -0,0 +1,47 @@ +#!/bin/bash -ex + +# Simple AutoCog Setup Script +# Creates local .venv with llama.cpp and AutoCog + +echo "Setting up AutoCog in local .venv..." + +# Create and activate virtual environment +python3 -m venv .venv +source .venv/bin/activate + +# Upgrade pip and install build dependencies +pip install --upgrade pip setuptools wheel pybind11 numpy + +exit 0 + +# Build and install llama.cpp to .venv +echo "Building llama.cpp..." +cd vendors/llama +cmake -B build \ + -DCMAKE_INSTALL_PREFIX="$(pwd)/../../.venv" \ + -DLLAMA_BUILD_COMMON=OFF \ + -DLLAMA_LOG_DISABLE=ON \ + -DCMAKE_BUILD_TYPE=Release + +make -C build -j$(nproc) install +cd ../.. + +# Set library path for AutoCog build +export LD_LIBRARY_PATH="$(pwd)/.venv/lib:$LD_LIBRARY_PATH" +export CMAKE_PREFIX_PATH="$(pwd)/.venv" +export CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) + +# Build and install AutoCog +echo "Building AutoCog..." +pip install -e . + +# Test installation +echo "Testing installation..." +python3 -c "import autocog.llama; print('Success!')" + +echo "Setup complete!" 
+echo "" +echo "To use AutoCog:" +echo " source .venv/bin/activate" +echo " python3 tests/autocog/llama/roundtrip_tokenization.py path/to/model.gguf" + diff --git a/share/demos/CMakeLists.txt b/share/demos/CMakeLists.txt new file mode 100644 index 0000000..3624d7d --- /dev/null +++ b/share/demos/CMakeLists.txt @@ -0,0 +1,4 @@ + + +add_subdirectory(story-writer) + diff --git a/share/demos/story-writer/CMakeLists.txt b/share/demos/story-writer/CMakeLists.txt new file mode 100644 index 0000000..f6caf8c --- /dev/null +++ b/share/demos/story-writer/CMakeLists.txt @@ -0,0 +1,4 @@ + +add_test(NAME demo-stlc-story-writer + COMMAND stlc -I ${PROJECT_SOURCE_DIR}/share/library -I ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/writer.stl) + diff --git a/tests/cli/qna/run.sh b/share/demos/story-writer/README.md old mode 100755 new mode 100644 similarity index 100% rename from tests/cli/qna/run.sh rename to share/demos/story-writer/README.md diff --git a/share/demos/story-writer/book.py b/share/demos/story-writer/book.py new file mode 100644 index 0000000..75dccda --- /dev/null +++ b/share/demos/story-writer/book.py @@ -0,0 +1,5 @@ + +def write(book): + # TODO format book to markdown + return "done" + diff --git a/share/demos/story-writer/book.stl b/share/demos/story-writer/book.stl new file mode 100644 index 0000000..3764c82 --- /dev/null +++ b/share/demos/story-writer/book.stl @@ -0,0 +1,77 @@ + +from "stlib/thoughts.stl" import Thought; +from "template.stl" import Step; + +record Sentence { + argument length=10; + is text; +} + +record Title { + is Sentence; +} + +record Page { + is { + short is Though; + illustration[1:3] is Sentence; + content[1:10] is Sentence; + } +} + +record Book { + is { + title is Title; + pages[5:100] is Page; + } +} + +prompt edit_title { + is { + topic[3:10] is text; + current is Title; + pondering[1:5] is Thought; + edited is Title; + } + channel { + topic get topic; + current get current; + } + return use edited; // return a Title 
(string) +} + +prompt create_pages { + is { + topic[3:10] is text; + step is Step; + prepare[1:5] is Thought; + pages[1:10] is { + imagine[1:5] is Thought; + page is Page; + } + } + channel { + topic get topic; + step get step; + } + return use pages.page; // return a list of Page +} + +prompt edit_page { + is { + task[1:10] is text; + topic[3:10] is text; + edits[1:10] is text; + current is Page; + pondering[1:5] is Thought; + edited is Page; + } + channel { + task get task; + topic get topic; + edits get edits; + current get page; + } + return use edited; // return a Page +} + diff --git a/share/demos/story-writer/template.py b/share/demos/story-writer/template.py new file mode 100644 index 0000000..eb9cee0 --- /dev/null +++ b/share/demos/story-writer/template.py @@ -0,0 +1,79 @@ + +templates = { + "bedtime" : { + "description" : [ + "Medium story to be read (many time) to young children.", + "Might take half-hour to read. Might read different part each night.", + "Nothing scary, it is bedtime. Avoid situation that could be stressful." 
+ ], + "ages" : [3, 5], + "sentence" : "simple with few preposition", + "vocabulary" : "a couple uncommon word introduced by each story", + "steps" : [ + { + "name": "protagonist", + "description" : "present the main character of the story", + "pages" : [1,3] + }, { + "name": "perturbation", + "description" : "the event that trigger the story", + "pages" : [1,2] + },{ + "name": "explorations", + "description" : "protagonist tries to figure out a solution by trying all sort of things", + "pages" : [5,20] + },{ + "name": "resolution", + "description" : "protagonist find the solution (or get help)", + "pages" : [2,4] + },{ + "name": "conclusion", + "description" : "everything is back to normal but protagonist learned something", + "pages" : [1,2] + } + ] + }, + "wordbook" : { + "description" : [ + "Short story that introduce simple word based on a theme", + "Takes a few minutes to present to young kids", + "Older kids can sounds the word by themself" + ], + "ages" : [1, 4], + "sentence" : "single word, or very simple sentence: subject and verb maybe adjective", + "vocabulary" : "all very common words", + "steps" : [ + { + "name": "theme", + "description" : "illustrate the theme of the book", + "pages" : [1,2] + + }, { + "name": "words", + "description" : "one word or concept per page", + "pages" : [5,10] + } + ] + } +} + +def list_templates(age): + return [ + { "template" : key, "description" : tpl['description'], "ages" : tpl['ages'] } + for (key,tpl) in templates.items() + if tpl['ages'][0] <= age and age <= tpl['ages'][1] + ] + +def open_template(key): + return templates[key] + +def collate_task(query, age, key): + res = [f"User ask for a book for children age {age}",f"User specified: \"{query}\""] + res += ["Template description:"] + res += [ f"> {d}" for d in templates[key]['description'] ] + res += [f"Expected sentence style: {templates[key]['sentence']}"] + res += [f"Expected vocabulary: {templates[key]['vocabulary']}"] + res += ["Template steps:"] + res += [ f"- 
{s['description']}" for s in templates[key]['steps'] ] + return res + diff --git a/share/demos/story-writer/template.stl b/share/demos/story-writer/template.stl new file mode 100644 index 0000000..5393cfc --- /dev/null +++ b/share/demos/story-writer/template.stl @@ -0,0 +1,35 @@ + +record Age { + is enum("1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"); +} + +record TplDesc { + is { + template is text; + description[1:5] is text; + ages[2] is Age; + } +} + +record Step { + is { + name is text; + description is text; + pages[2] is text; + } +} + +record Template { + is { + description[1:5] is text; + ages[2] is Age; + sentence is text; + vocabulary is text; + steps[2:10] is { + name is text; + description is text; + pages[2] is text; + } + } +} + diff --git a/share/demos/story-writer/writer.stl b/share/demos/story-writer/writer.stl new file mode 100644 index 0000000..f31451c --- /dev/null +++ b/share/demos/story-writer/writer.stl @@ -0,0 +1,206 @@ + +from "stlib/thoughts.stl" import Thought, reflexion; +from "stlib/datastore.py" import store as stlib_store, retrieve as stlib_retrieve; + +from "template.stl" import Age, TplDesc, Template; +from "template.py" import list_templates, open_template, collate_task; + +from "book.stl" import Title, Book; +from "book.stl" import edit_title, create_pages, edit_page; +from "book.py" import write as book_write; + +alias Thought as BrainstormThought; + +prompt init_idea { + is { + query is text; + age is Age; + templates[1:10] is TplDesc; + pick is repeat(templates.template); + ideas[1:5] is BrainstormThought; + title is Title; + } + flow init_template; + channel { + query get query; + age get age; + templates call list_templates { + age use .age; + }; + } +} + +prompt init_task { + is { + task[1:10] is text; + template is Template; + } + channel { + task call collate_task { + query get query; + age get age; + key use init_idea.pick; + }; + template call open_template { + key get init_idea.pick; + }; + } + flow 
init_topic; +} + +prompt init_topic { + is { + topic[3:10] is text; // Call to reflexion produces Thought + } + channel { + topic call reflexion< + subject="From rough book idea to refined book topic.", + length=50, + mode="refined", + goal=f"guide the writing of a children book", + num_steps=10, + len_steps=20, + min_response=3, + max_response=10 + > { + task use init_task.task; + initial use init_idea.idea; + } bind(_,response); + } + flow init_draft; +} + +prompt init_draft { + is { + book is Book; + } + channel { + book.title call edit_title { + topic use init_topic.topic; + current use init_idea.title; + }; + book.pages call create_pages { + topic use init_topic.topic; + step use init_task.template.steps mapped; + } ravel; + } + flow init_commit; +} + +prompt init_commit { + is { + keys is { + pkey is text; + skey is text; + } + } + channel { + done call stlib_store { + pkey is "book"; + data use init_draft.book; + }; + } + flow loop_cond; +} + +prompt loop_cond { + argument max_loops=10; + is { + task[1:10] is text; + topic[3:10] is text; + book is Book; + comments[1:10] is Thought; + } + channel { + task use init_task.task; + topic use init_topic.topic; + book call stlib_retrieve { + pkey is "book"; + }; + } + flow { + loop_analyse[max_loops] as "edit"; + done as "done"; + } +} + +prompt loop_collate { + is { + collated[1:100] is text; + } + channel { + collated call collate_comment { + topic use loop_cond.topic; + comments use loop_cond.comments; + }; + } + flow loop_analyse; +} + +prompt loop_analyse { + is { + edits[1:10] is Thought; + } + channel { + edits call reflexion< + subject="From editor comments to elaborated edit directives.", + length=50, + mode="elaborated", + goal=f"provide clear edit guidance", + num_steps=10, + len_steps=20, + min_response=1, + max_response=10 + > { + task use loop_cond.task; + initial use loop_collate.collated; + }; + } + flow loop_dispatch; +} + +prompt loop_dispatch { + is { + book is Book; + } + channel { + book.title 
use loop_cond.book.title; + book.pages call edit_page { + task use loop_cond.task; + topic use loop_cond.topic; + edit use loop_analyse.edits; + page use loop_cond.book.pages mapped; + }; + } + flow loop_commit; +} + +prompt loop_commit { + is { + keys is { + pkey is text; + skey is text; + } + } + channel { + done call stlib_store { + pkey is "book"; + data use loop_dispatch.book; + }; + } + flow loop_cond; +} + +prompt done { + is { + done is text; + } + channel { + done call book_write { + book use loop_cond.book; + }; + } +} + +export init_idea as main; + diff --git a/share/library/elementary/README.md b/share/experiments/elementary/README.md similarity index 100% rename from share/library/elementary/README.md rename to share/experiments/elementary/README.md diff --git a/share/library/elementary/multiply-chain.sta b/share/experiments/elementary/multiply-chain.sta similarity index 100% rename from share/library/elementary/multiply-chain.sta rename to share/experiments/elementary/multiply-chain.sta diff --git a/share/library/elementary/multiply-single.sta b/share/experiments/elementary/multiply-single.sta similarity index 100% rename from share/library/elementary/multiply-single.sta rename to share/experiments/elementary/multiply-single.sta diff --git a/share/library/elementary/wip.ipynb b/share/experiments/elementary/wip.ipynb similarity index 100% rename from share/library/elementary/wip.ipynb rename to share/experiments/elementary/wip.ipynb diff --git a/share/library/mcq/README.md b/share/experiments/mcq/README.md similarity index 100% rename from share/library/mcq/README.md rename to share/experiments/mcq/README.md diff --git a/share/library/mcq/all.json b/share/experiments/mcq/all.json similarity index 100% rename from share/library/mcq/all.json rename to share/experiments/mcq/all.json diff --git a/share/library/mcq/repeat-annot.sta b/share/experiments/mcq/repeat-annot.sta similarity index 100% rename from share/library/mcq/repeat-annot.sta rename to 
share/experiments/mcq/repeat-annot.sta diff --git a/share/library/mcq/repeat-cot.sta b/share/experiments/mcq/repeat-cot.sta similarity index 100% rename from share/library/mcq/repeat-cot.sta rename to share/experiments/mcq/repeat-cot.sta diff --git a/share/library/mcq/repeat-hyp.sta b/share/experiments/mcq/repeat-hyp.sta similarity index 100% rename from share/library/mcq/repeat-hyp.sta rename to share/experiments/mcq/repeat-hyp.sta diff --git a/share/library/mcq/repeat-iter.sta b/share/experiments/mcq/repeat-iter.sta similarity index 100% rename from share/library/mcq/repeat-iter.sta rename to share/experiments/mcq/repeat-iter.sta diff --git a/share/library/mcq/repeat.sta b/share/experiments/mcq/repeat.sta similarity index 100% rename from share/library/mcq/repeat.sta rename to share/experiments/mcq/repeat.sta diff --git a/share/library/mcq/select-annot.sta b/share/experiments/mcq/select-annot.sta similarity index 100% rename from share/library/mcq/select-annot.sta rename to share/experiments/mcq/select-annot.sta diff --git a/share/library/mcq/select-cot.sta b/share/experiments/mcq/select-cot.sta similarity index 100% rename from share/library/mcq/select-cot.sta rename to share/experiments/mcq/select-cot.sta diff --git a/share/library/mcq/select-hyp.sta b/share/experiments/mcq/select-hyp.sta similarity index 100% rename from share/library/mcq/select-hyp.sta rename to share/experiments/mcq/select-hyp.sta diff --git a/share/library/mcq/select-iter.sta b/share/experiments/mcq/select-iter.sta similarity index 100% rename from share/library/mcq/select-iter.sta rename to share/experiments/mcq/select-iter.sta diff --git a/share/library/mcq/select.sta b/share/experiments/mcq/select.sta similarity index 100% rename from share/library/mcq/select.sta rename to share/experiments/mcq/select.sta diff --git a/share/library/mcq/wip.ipynb b/share/experiments/mcq/wip.ipynb similarity index 100% rename from share/library/mcq/wip.ipynb rename to share/experiments/mcq/wip.ipynb diff 
--git a/share/library/README.md b/share/library/README.md deleted file mode 100644 index e92c13b..0000000 --- a/share/library/README.md +++ /dev/null @@ -1,12 +0,0 @@ -AutoCog library of Cogs -======================= - -Currently, Cogs are either Structured Thought Automaton (STA) or Python (PY). -STAs is our cognitive programming language used to drive LLMs. -Python permits us to express simple data operation or build complex tools. - -List of libraries: - - [`mcq` (STA)](./mcq/README.md): Solving multiple choices question using a variety of pattern. - - [`dfl` (PY)](./dfl/README.md): Collection of simple data operations to facilitate implementing cognitive patterns. - - [`elementary` (STA)](./elementary/README.md): Focusses on very simple tasks at the elementary school level. - - [`tools` (PY)](./tools/README.md): Collection of tools that Cognitive application can use to interact with the world. diff --git a/share/library/dfl/README.md b/share/library/dfl/README.md deleted file mode 100644 index 2a1eb2c..0000000 --- a/share/library/dfl/README.md +++ /dev/null @@ -1,8 +0,0 @@ -Data Flow Library -================= - -This is a utility library that provides primitives to manipulate various objects: - - inserting/deleting in a list - - constructing a dictionary - - ... 
- diff --git a/share/library/stlib/datastore.py b/share/library/stlib/datastore.py new file mode 100644 index 0000000..34ccdce --- /dev/null +++ b/share/library/stlib/datastore.py @@ -0,0 +1,19 @@ + +_data = {} + +def store(data, pkey="", skey=""): + global _data + if not pkey in _data: + _data.update({ pkey : { skey : data } }) + else: + _data[pkey].update({ skey : data }) + return { "pkey" : pkey, "skey" : skey } + +def retrieve(pkey="", skey=""): + global _data + if not pkey in _data: + return None + elif not skey in _data[pkey]: + return None + return _data[pkey][skey] + diff --git a/share/library/stlib/thoughts.stl b/share/library/stlib/thoughts.stl new file mode 100644 index 0000000..33f80f2 --- /dev/null +++ b/share/library/stlib/thoughts.stl @@ -0,0 +1,44 @@ + +define default_length = 20; +define default_mode = "simple"; +define default_goal = "show your work"; + +record Thought { + argument length=default_length; + argument mode=default_mode; + argument goal=default_goal; + + is text; + + annotate f"A {mode} thought to {goal}."; +} + +prompt reflexion { + argument subject; + + argument length = default_length; + argument mode = default_mode; + argument goal = default_goal; + + argument num_steps = 10; + argument len_steps = 20; + argument min_response = 3; + argument max_response = 10; + + is { + task[1:10] is text; + initial[1:10] is text; + work[1:num_steps] is Thought; + response[min_response:max_response] is Thought; + } + channel { + task get task; + initial get initial; + } + return { + subject is f"Thought about {subject}"; + intermediate use work; + use response; + } +} + diff --git a/share/syntax-highlight/gedit/stl.lang b/share/syntax-highlight/gedit/stl.lang new file mode 100644 index 0000000..a1837e5 --- /dev/null +++ b/share/syntax-highlight/gedit/stl.lang @@ -0,0 +1,110 @@ + + + + *.stl + // + + + +