From 813db3ffe316f107180a03750b30b0b3726e1fb2 Mon Sep 17 00:00:00 2001 From: Carl Guo Date: Tue, 31 Jan 2023 17:48:26 -0500 Subject: [PATCH 01/15] init commit for sql dialect --- include/CMakeLists.txt | 1 + include/sql/CMakeLists.txt | 3 ++ include/sql/SQLDialect.h | 16 ++++++++ include/sql/SQLDialect.td | 44 +++++++++++++++++++++ include/sql/SQLOps.h | 23 +++++++++++ include/sql/SQLOps.td | 57 +++++++++++++++++++++++++++ lib/CMakeLists.txt | 1 + lib/sql/CMakeLists.txt | 19 +++++++++ lib/sql/Dialect.cpp | 27 +++++++++++++ lib/sql/Ops.cpp | 42 ++++++++++++++++++++ test/polygeist-opt/sql.mlir | 11 ++++++ tools/polygeist-opt/CMakeLists.txt | 1 - tools/polygeist-opt/polygeist-opt.cpp | 5 +++ 13 files changed, 249 insertions(+), 1 deletion(-) create mode 100644 include/sql/CMakeLists.txt create mode 100644 include/sql/SQLDialect.h create mode 100644 include/sql/SQLDialect.td create mode 100644 include/sql/SQLOps.h create mode 100644 include/sql/SQLOps.td create mode 100644 lib/sql/CMakeLists.txt create mode 100644 lib/sql/Dialect.cpp create mode 100644 lib/sql/Ops.cpp create mode 100644 test/polygeist-opt/sql.mlir diff --git a/include/CMakeLists.txt b/include/CMakeLists.txt index da66b9bf293f..f5e589845278 100644 --- a/include/CMakeLists.txt +++ b/include/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(polygeist) +add_subdirectory(sql) \ No newline at end of file diff --git a/include/sql/CMakeLists.txt b/include/sql/CMakeLists.txt new file mode 100644 index 000000000000..5de01d8b95a8 --- /dev/null +++ b/include/sql/CMakeLists.txt @@ -0,0 +1,3 @@ +add_mlir_dialect(SQLOps sql) +# add_mlir_doc(SQLDialect -gen-dialect-doc SQLDialect SQL/) +# add_mlir_doc(SQLOps -gen-op-doc SQLOps SQL/) diff --git a/include/sql/SQLDialect.h b/include/sql/SQLDialect.h new file mode 100644 index 000000000000..49d79f5b37dd --- /dev/null +++ b/include/sql/SQLDialect.h @@ -0,0 +1,16 @@ +//===- BFVDialect.h - BFV dialect -----------------*- C++ -*-===// +// +// This file is licensed under the 
Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef SQL_DIALECT +#define SQL_DIALECT + +#include "mlir/IR/Dialect.h" + +#include "sql/SQLOpsDialect.h.inc" + +#endif diff --git a/include/sql/SQLDialect.td b/include/sql/SQLDialect.td new file mode 100644 index 000000000000..21709a34d694 --- /dev/null +++ b/include/sql/SQLDialect.td @@ -0,0 +1,44 @@ +//===- BFVDialect.td - BFV dialect -----------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef SQL_DIALECT +#define SQL_DIALECT + + +include "mlir/IR/OpBase.td" +include "mlir/IR/FunctionInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/CastInterfaces.td" + + +def SQL_Dialect : Dialect { + let summary = "A dialect for SQL languages in MLIR."; + let description = [{ + TBD + }]; + let name = "sql"; + let cppNamespace = "::mlir::sql"; +} + + +//===----------------------------------------------------------------------===// +// SQL Operations +//===----------------------------------------------------------------------===// + +class SQL_Op traits = []> + : Op; + +#endif // SQL_DIALECT + + + + + diff --git a/include/sql/SQLOps.h b/include/sql/SQLOps.h new file mode 100644 index 000000000000..131176c222c9 --- /dev/null +++ b/include/sql/SQLOps.h @@ -0,0 +1,23 @@ +//===- Polygeistps.h - Polygeist dialect ops --------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef SQLOPS_H +#define SQLOPS_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ViewLikeInterface.h" +#include "llvm/Support/CommandLine.h" + +#define GET_OP_CLASSES +#include "sql/SQLOps.h.inc" + +#endif \ No newline at end of file diff --git a/include/sql/SQLOps.td b/include/sql/SQLOps.td new file mode 100644 index 000000000000..0e917649e8eb --- /dev/null +++ b/include/sql/SQLOps.td @@ -0,0 +1,57 @@ +//===- SQLOps.td - Polygeist dialect ops ----------------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// #ifndef SQL_OPS +// #define SQL_OPS + +include "mlir/IR/AttrTypeBase.td" +include "SQLDialect.td" + +def SelectOp : SQL_Op<"select", [Pure]> { + let summary = "select"; + + // TODO: limit (optional), where clauses, join, etc + let arguments = (ins StrArrayAttr:$column, StrAttr:$table); + let results = (outs Index : $result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +def ExecuteOp : SQL_Op<"execute", []> { + let summary = "execute query"; + + let arguments = (ins Index:$handle); + let results = (outs Index:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +def NumResultsOp : SQL_Op<"num_results", [Pure]> { + let summary = "number of results"; + + let arguments = (ins Index:$handle); + let results = (outs Index:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +def ResultOp : SQL_Op<"get_result", [Pure]> { + let summary = "get results of execution"; + + let arguments = (ins Index:$handle, StrAttr:$column, Index:$row); + let results = (outs AnyType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + + +// #endif \ No newline at end of file diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index da66b9bf293f..f5e589845278 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(polygeist) +add_subdirectory(sql) \ No newline at end of file diff --git a/lib/sql/CMakeLists.txt b/lib/sql/CMakeLists.txt new file mode 100644 index 000000000000..4cbfd895bde3 --- /dev/null +++ b/lib/sql/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_dialect_library(MLIRSQL +Dialect.cpp +Ops.cpp + +ADDITIONAL_HEADER_DIRS +${PROJECT_SOURCE_DIR}/include/sql + +DEPENDS +MLIRSQLOpsIncGen + +LINK_LIBS PUBLIC +MLIRIR +MLIRMemRefDialect +MLIRLLVMDialect +MLIROpenMPDialect +MLIRAffineDialect +MLIRSupport +MLIRSCFTransforms +) diff --git a/lib/sql/Dialect.cpp 
b/lib/sql/Dialect.cpp new file mode 100644 index 000000000000..f05419f40319 --- /dev/null +++ b/lib/sql/Dialect.cpp @@ -0,0 +1,27 @@ +//===- PolygeistDialect.cpp - Polygeist dialect ---------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "sql/SQLDialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "sql/SQLOps.h" + +using namespace mlir; +using namespace mlir::sql; + +//===----------------------------------------------------------------------===// +// Polygeist dialect. +//===----------------------------------------------------------------------===// + +void SQLDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "sql/SQLOps.cpp.inc" + >(); +} + +#include "sql/SQLOpsDialect.cpp.inc" diff --git a/lib/sql/Ops.cpp b/lib/sql/Ops.cpp new file mode 100644 index 000000000000..8ca225a665da --- /dev/null +++ b/lib/sql/Ops.cpp @@ -0,0 +1,42 @@ +//===- PolygeistOps.cpp - BFV dialect ops ---------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "sql/SQLOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "sql/SQLDialect.h" + +#define GET_OP_CLASSES +#include "sql/SQLOps.cpp.inc" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/Transforms/SideEffectUtils.h" + +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "sql" + +using namespace mlir; +using namespace sql; +using namespace mlir::arith; + + diff --git a/test/polygeist-opt/sql.mlir b/test/polygeist-opt/sql.mlir new file mode 100644 index 000000000000..6e7747b0d40e --- /dev/null +++ b/test/polygeist-opt/sql.mlir @@ -0,0 +1,11 @@ +// RUN: polygeist-opt %s | FileCheck %s + +module { + func.func private @run() -> i32 { + %c0 = arith.constant 0 : index + %q = "sql.select"() {column = ["data"], table = "mytable"} : () -> index + %h = "sql.execute"(%q) : (index) -> index + %res = "sql.get_result"(%h, %c0) {column = "data"} : (index, index) -> i32 + return %res : i32 + } +} \ No newline at end of file diff --git a/tools/polygeist-opt/CMakeLists.txt b/tools/polygeist-opt/CMakeLists.txt index ccfebd421d81..c2fd7c41ec4c 100644 --- a/tools/polygeist-opt/CMakeLists.txt +++ b/tools/polygeist-opt/CMakeLists.txt @@ -6,7 +6,6 @@ set(LIBS MLIROptLib MLIRPolygeist MLIRPolygeistTransforms - 
MLIRFuncAllExtensions ) add_llvm_executable(polygeist-opt polygeist-opt.cpp) diff --git a/tools/polygeist-opt/polygeist-opt.cpp b/tools/polygeist-opt/polygeist-opt.cpp index 95fe1b1fc4a4..34d6f9bad11a 100644 --- a/tools/polygeist-opt/polygeist-opt.cpp +++ b/tools/polygeist-opt/polygeist-opt.cpp @@ -33,6 +33,9 @@ #include "polygeist/Dialect.h" #include "polygeist/Passes/Passes.h" +#include "sql/SQLDialect.h" +#include "sql/SQLOps.h" + using namespace mlir; class MemRefInsider @@ -62,6 +65,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); + registry.insert(); mlir::registerpolygeistPasses(); mlir::func::registerInlinerExtension(registry); @@ -76,6 +80,7 @@ int main(int argc, char **argv) { mlir::registerConvertSCFToOpenMPPass(); mlir::affine::registerAffinePasses(); + registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) { LLVM::LLVMFunctionType::attachInterface(*ctx); }); From 764a5a0270e9158ffd9145b69315e139f69544d9 Mon Sep 17 00:00:00 2001 From: Carl Guo Date: Tue, 31 Jan 2023 21:29:22 -0500 Subject: [PATCH 02/15] reformat --- include/sql/SQLDialect.h | 2 +- lib/sql/Dialect.cpp | 2 +- lib/sql/Ops.cpp | 6 ++---- tools/polygeist-opt/polygeist-opt.cpp | 1 - 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/include/sql/SQLDialect.h b/include/sql/SQLDialect.h index 49d79f5b37dd..6c5b6c07797d 100644 --- a/include/sql/SQLDialect.h +++ b/include/sql/SQLDialect.h @@ -13,4 +13,4 @@ #include "sql/SQLOpsDialect.h.inc" -#endif +#endif \ No newline at end of file diff --git a/lib/sql/Dialect.cpp b/lib/sql/Dialect.cpp index f05419f40319..372a6e5778e6 100644 --- a/lib/sql/Dialect.cpp +++ b/lib/sql/Dialect.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#include "sql/SQLDialect.h" #include "mlir/IR/DialectImplementation.h" +#include "sql/SQLDialect.h" #include "sql/SQLOps.h" using namespace mlir; diff --git a/lib/sql/Ops.cpp b/lib/sql/Ops.cpp index 
8ca225a665da..a7badbedefe8 100644 --- a/lib/sql/Ops.cpp +++ b/lib/sql/Ops.cpp @@ -6,7 +6,6 @@ // //===----------------------------------------------------------------------===// -#include "sql/SQLOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" @@ -15,6 +14,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "sql/SQLDialect.h" +#include "sql/SQLOps.h" #define GET_OP_CLASSES #include "sql/SQLOps.cpp.inc" @@ -37,6 +37,4 @@ using namespace mlir; using namespace sql; -using namespace mlir::arith; - - +using namespace mlir::arith; \ No newline at end of file diff --git a/tools/polygeist-opt/polygeist-opt.cpp b/tools/polygeist-opt/polygeist-opt.cpp index 34d6f9bad11a..b642cc5aa076 100644 --- a/tools/polygeist-opt/polygeist-opt.cpp +++ b/tools/polygeist-opt/polygeist-opt.cpp @@ -80,7 +80,6 @@ int main(int argc, char **argv) { mlir::registerConvertSCFToOpenMPPass(); mlir::affine::registerAffinePasses(); - registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) { LLVM::LLVMFunctionType::attachInterface(*ctx); }); From 575584f554a528b1aa7301d17db3d3631584ec60 Mon Sep 17 00:00:00 2001 From: Carl Guo Date: Sat, 4 Feb 2023 11:36:15 -0500 Subject: [PATCH 03/15] reformat names; remove unimportant dependencies; add include guards --- include/sql/SQLDialect.h | 2 +- include/sql/SQLDialect.td | 2 +- include/sql/SQLOps.h | 2 +- include/sql/SQLOps.td | 8 ++++---- lib/sql/CMakeLists.txt | 6 ------ 5 files changed, 7 insertions(+), 13 deletions(-) diff --git a/include/sql/SQLDialect.h b/include/sql/SQLDialect.h index 6c5b6c07797d..0ac736f6693e 100644 --- a/include/sql/SQLDialect.h +++ b/include/sql/SQLDialect.h @@ -1,4 +1,4 @@ -//===- BFVDialect.h - BFV dialect -----------------*- C++ -*-===// +//===- SQLDialect.h - SQL dialect -----------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM 
Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/include/sql/SQLDialect.td b/include/sql/SQLDialect.td index 21709a34d694..0bd8af6bcb0e 100644 --- a/include/sql/SQLDialect.td +++ b/include/sql/SQLDialect.td @@ -1,4 +1,4 @@ -//===- BFVDialect.td - BFV dialect -----------*- tablegen -*-===// +//===- SQLDialect.td - SQL dialect -----------*- tablegen -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/include/sql/SQLOps.h b/include/sql/SQLOps.h index 131176c222c9..06d0319150bb 100644 --- a/include/sql/SQLOps.h +++ b/include/sql/SQLOps.h @@ -1,4 +1,4 @@ -//===- Polygeistps.h - Polygeist dialect ops --------------------*- C++ -*-===// +//===- SQLOps.h - SQL dialect ops --------------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/include/sql/SQLOps.td b/include/sql/SQLOps.td index 0e917649e8eb..cfdb3999a87e 100644 --- a/include/sql/SQLOps.td +++ b/include/sql/SQLOps.td @@ -1,4 +1,4 @@ -//===- SQLOps.td - Polygeist dialect ops ----------------*- tablegen -*-===// +//===- SQLOps.td - SQL dialect ops ----------------*- tablegen -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -// #ifndef SQL_OPS -// #define SQL_OPS +#ifndef SQL_OPS +#define SQL_OPS include "mlir/IR/AttrTypeBase.td" include "SQLDialect.td" @@ -54,4 +54,4 @@ def ResultOp : SQL_Op<"get_result", [Pure]> { } -// #endif \ No newline at end of file +#endif \ No newline at end of file diff --git a/lib/sql/CMakeLists.txt b/lib/sql/CMakeLists.txt index 4cbfd895bde3..e57efadee4d7 100644 --- a/lib/sql/CMakeLists.txt +++ b/lib/sql/CMakeLists.txt @@ -10,10 +10,4 @@ MLIRSQLOpsIncGen LINK_LIBS PUBLIC MLIRIR -MLIRMemRefDialect -MLIRLLVMDialect -MLIROpenMPDialect -MLIRAffineDialect -MLIRSupport -MLIRSCFTransforms ) From 2e7bb44917fa4807404c59437dfd295a6bf19db2 Mon Sep 17 00:00:00 2001 From: Carl Guo Date: Tue, 14 Feb 2023 14:30:38 -0500 Subject: [PATCH 04/15] fix build issue --- lib/sql/Dialect.cpp | 4 ++-- lib/sql/Ops.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/sql/Dialect.cpp b/lib/sql/Dialect.cpp index 372a6e5778e6..eb42b990bf4a 100644 --- a/lib/sql/Dialect.cpp +++ b/lib/sql/Dialect.cpp @@ -1,4 +1,4 @@ -//===- PolygeistDialect.cpp - Polygeist dialect ---------------*- C++ -*-===// +//===- SQLDialect.cpp - SQL dialect ---------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -14,7 +14,7 @@ using namespace mlir; using namespace mlir::sql; //===----------------------------------------------------------------------===// -// Polygeist dialect. +// SQL dialect. 
//===----------------------------------------------------------------------===// void SQLDialect::initialize() { diff --git a/lib/sql/Ops.cpp b/lib/sql/Ops.cpp index a7badbedefe8..3cedd57e62f9 100644 --- a/lib/sql/Ops.cpp +++ b/lib/sql/Ops.cpp @@ -1,4 +1,4 @@ -//===- PolygeistOps.cpp - BFV dialect ops ---------------*- C++ -*-===// +//===- SQLOps.cpp - SQL dialect ops ---------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. From a324331e142dac7cda3e698dca14d80d0611d82f Mon Sep 17 00:00:00 2001 From: Carl Guo Date: Tue, 14 Feb 2023 14:30:38 -0500 Subject: [PATCH 05/15] fix build issue & cleanup names --- CMakeLists.txt | 46 +++++++++++++++++++++++++++++-------- include/sql/SQLDialect.h | 4 ++-- include/sql/SQLOps.td | 3 +-- lib/sql/CMakeLists.txt | 2 ++ test/polygeist-opt/sql.mlir | 2 +- 5 files changed, 43 insertions(+), 14 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 59e116c6414e..3635652592b6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,8 +2,7 @@ cmake_minimum_required(VERSION 3.10) include(CheckCXXSourceCompiles) -set(POLYGEIST_ENABLE_CUDA 0 CACHE BOOL "Enable CUDA frontend and backend") -set(POLYGEIST_ENABLE_ROCM 0 CACHE BOOL "Enable ROCM backend") +set(POLYGEIST_ENABLE_CUDA 0 CACHE BOOL "Enable CUDA compilation support") if(POLICY CMP0068) cmake_policy(SET CMP0068 NEW) @@ -21,16 +20,11 @@ endif() option(LLVM_INCLUDE_TOOLS "Generate build targets for the LLVM tools." ON) option(LLVM_BUILD_TOOLS "Build the LLVM tools. If OFF, just generate build targets." 
ON) +option(ENABLE_SQL "Build SQL dialect" OFF) + set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin) set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib) -find_program(XXD_BIN xxd) - -# TODO should depend on OS -set(POLYGEIST_PGO_DEFAULT_DATA_DIR "/var/tmp/polygeist/pgo/" CACHE STRING "Directory for PGO data") -set(POLYGEIST_PGO_ALTERNATIVE_ENV_VAR "POLYGEIST_PGO_ALTERNATIVE" CACHE STRING "Env var name to specify alternative to profile") -set(POLYGEIST_PGO_DATA_DIR_ENV_VAR "POLYGEIST_PGO_DATA_DIR" CACHE STRING "Env var name to specify PGO data dir") - if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) project(polygeist LANGUAGES CXX C) @@ -112,6 +106,40 @@ set(LLVM_LIT_ARGS "-sv" CACHE STRING "lit default options") list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake/modules") include(sanitizers) +if (ENABLE_SQL) + include(FetchContent) + include(ExternalProject) + + FetchContent_Declare(sqlparser_ext + GIT_REPOSITORY https://github.com/wsmoses/sql-parser + GIT_TAG c2471248cef8cd33081e698e8ac65d691283dbd4 + ) + + FetchContent_GetProperties(sqlparser_ext) + + FetchContent_MakeAvailable(sqlparser_ext) + + ExternalProject_Add(sqlparser + SOURCE_DIR ${sqlparser_ext_SOURCE_DIR} + INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}/sql/install + CONFIGURE_COMMAND "" + BUILD_COMMAND ${CMAKE_COMMAND} -E env + CXX=${CMAKE_CXX_COMPILER} + make static=yes -C ${sqlparser_ext_SOURCE_DIR} + BUILD_IN_SOURCE TRUE + INSTALL_COMMAND "" + BUILD_BYPRODUCTS ${sqlparser_ext_SOURCE_DIR}/libsqlparser.a + ) + + + add_library(sqlparse_lib INTERFACE) + + target_include_directories(sqlparse_lib INTERFACE "${sqlparser_ext_SOURCE_DIR}/src") + target_link_libraries(sqlparse_lib INTERFACE ${sqlparser_ext_SOURCE_DIR}/libsqlparser.a) + add_dependencies(sqlparse_lib sqlparser) + +endif() + add_subdirectory(include) add_subdirectory(lib) add_subdirectory(tools) diff --git a/include/sql/SQLDialect.h b/include/sql/SQLDialect.h index 0ac736f6693e..bdb0178a0662 100644 --- 
a/include/sql/SQLDialect.h +++ b/include/sql/SQLDialect.h @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef SQL_DIALECT -#define SQL_DIALECT +#ifndef SQL_DIALECT_H +#define SQL_DIALECT_H #include "mlir/IR/Dialect.h" diff --git a/include/sql/SQLOps.td b/include/sql/SQLOps.td index cfdb3999a87e..1acbcfeb6374 100644 --- a/include/sql/SQLOps.td +++ b/include/sql/SQLOps.td @@ -53,5 +53,4 @@ def ResultOp : SQL_Op<"get_result", [Pure]> { let hasCanonicalizer = 0; } - -#endif \ No newline at end of file +#endif // SQL_OPS \ No newline at end of file diff --git a/lib/sql/CMakeLists.txt b/lib/sql/CMakeLists.txt index e57efadee4d7..cd4bcfde1856 100644 --- a/lib/sql/CMakeLists.txt +++ b/lib/sql/CMakeLists.txt @@ -11,3 +11,5 @@ MLIRSQLOpsIncGen LINK_LIBS PUBLIC MLIRIR ) + +add_subdirectory(Passes) \ No newline at end of file diff --git a/test/polygeist-opt/sql.mlir b/test/polygeist-opt/sql.mlir index 6e7747b0d40e..024215c97895 100644 --- a/test/polygeist-opt/sql.mlir +++ b/test/polygeist-opt/sql.mlir @@ -1,5 +1,5 @@ // RUN: polygeist-opt %s | FileCheck %s - +// -lower-sql module { func.func private @run() -> i32 { %c0 = arith.constant 0 : index From 7dae0f759fbb0629df9bef0408d44a48a6455fae Mon Sep 17 00:00:00 2001 From: Carl Guo Date: Fri, 3 Mar 2023 13:53:27 -0500 Subject: [PATCH 06/15] SQLLower --- include/sql/CMakeLists.txt | 2 + include/sql/Passes/CMakeLists.txt | 5 ++ include/sql/Passes/Passes.h | 48 +++++++++++ include/sql/Passes/Passes.td | 14 +++ include/sql/Passes/Utils.h | 139 ++++++++++++++++++++++++++++++ lib/sql/Passes/CMakeLists.txt | 17 ++++ lib/sql/Passes/SQLLower.cpp | 120 ++++++++++++++++++++++++++ 7 files changed, 345 insertions(+) create mode 100644 include/sql/Passes/CMakeLists.txt create mode 100644 include/sql/Passes/Passes.h create mode 100644 include/sql/Passes/Passes.td create mode 100644 include/sql/Passes/Utils.h create mode 100644 lib/sql/Passes/CMakeLists.txt create mode 100644 
lib/sql/Passes/SQLLower.cpp diff --git a/include/sql/CMakeLists.txt b/include/sql/CMakeLists.txt index 5de01d8b95a8..8764d8eddbb4 100644 --- a/include/sql/CMakeLists.txt +++ b/include/sql/CMakeLists.txt @@ -1,3 +1,5 @@ add_mlir_dialect(SQLOps sql) # add_mlir_doc(SQLDialect -gen-dialect-doc SQLDialect SQL/) # add_mlir_doc(SQLOps -gen-op-doc SQLOps SQL/) + +add_subdirectory(Passes) \ No newline at end of file diff --git a/include/sql/Passes/CMakeLists.txt b/include/sql/Passes/CMakeLists.txt new file mode 100644 index 000000000000..298ddabf48f7 --- /dev/null +++ b/include/sql/Passes/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name sql) +add_public_tablegen_target(MLIRSQLPassIncGen) + +add_mlir_doc(Passes SQLPasses ./ -gen-pass-doc) diff --git a/include/sql/Passes/Passes.h b/include/sql/Passes/Passes.h new file mode 100644 index 000000000000..3c0861eeab4e --- /dev/null +++ b/include/sql/Passes/Passes.h @@ -0,0 +1,48 @@ +#ifndef SQL_DIALECT_SQL_PASSES_H +#define SQL_DIALECT_SQL_PASSES_H + +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Pass/Pass.h" +#include +namespace mlir { +class PatternRewriter; +class RewritePatternSet; +class DominanceInfo; +namespace sql { + +std::unique_ptr createParallelLowerPass(); +} // namespace sql +} // namespace mlir + +namespace mlir { +// Forward declaration from Dialect.h +template +void registerDialect(DialectRegistry ®istry); + +namespace arith { +class ArithDialect; +} // end namespace arith + +namespace scf { +class SCFDialect; +} // end namespace scf + +namespace memref { +class MemRefDialect; +} // end namespace memref + +namespace func { +class FuncDialect; +} + +class AffineDialect; +namespace LLVM { +class LLVMDialect; +} + +#define GEN_PASS_REGISTRATION +#include "sql/Passes/Passes.h.inc" + +} // end namespace mlir + +#endif // SQL_DIALECT_SQL_PASSES_H diff --git a/include/sql/Passes/Passes.td b/include/sql/Passes/Passes.td new 
file mode 100644 index 000000000000..36faaab4a52f --- /dev/null +++ b/include/sql/Passes/Passes.td @@ -0,0 +1,14 @@ +#ifndef SQL_PASSES +#define SQL_PASSES + +include "mlir/Pass/PassBase.td" + + +def ParallelLower : Pass<"sql-lower", "mlir::ModuleOp"> { + let summary = "Lower sql op to mlir"; + let dependentDialects = + ["arith::AirthDialect", "func::FuncDialect", "LLVM::LLVMDialect"]; + let constructor = "mlir::sql::createSQLLowerPass()"; +} + +#endif // SQL_PASSES diff --git a/include/sql/Passes/Utils.h b/include/sql/Passes/Utils.h new file mode 100644 index 000000000000..b191d89ba6a6 --- /dev/null +++ b/include/sql/Passes/Utils.h @@ -0,0 +1,139 @@ +#pragma once + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/IntegerSet.h" + +static inline mlir::scf::IfOp +cloneWithResults(mlir::scf::IfOp op, mlir::OpBuilder &rewriter, + mlir::BlockAndValueMapping mapping = {}) { + using namespace mlir; + return rewriter.create(op.getLoc(), op.getResultTypes(), + mapping.lookupOrDefault(op.getCondition()), + true); +} +static inline mlir::AffineIfOp +cloneWithResults(mlir::AffineIfOp op, mlir::OpBuilder &rewriter, + mlir::BlockAndValueMapping mapping = {}) { + using namespace mlir; + SmallVector lower; + for (auto o : op.getOperands()) + lower.push_back(mapping.lookupOrDefault(o)); + return rewriter.create(op.getLoc(), op.getResultTypes(), + op.getIntegerSet(), lower, true); +} + +static inline mlir::scf::IfOp +cloneWithoutResults(mlir::scf::IfOp op, mlir::OpBuilder &rewriter, + mlir::BlockAndValueMapping mapping = {}, + mlir::TypeRange types = {}) { + using namespace mlir; + return rewriter.create( + op.getLoc(), types, mapping.lookupOrDefault(op.getCondition()), true); +} +static inline mlir::AffineIfOp +cloneWithoutResults(mlir::AffineIfOp op, mlir::OpBuilder &rewriter, + mlir::BlockAndValueMapping mapping = {}, + mlir::TypeRange types = {}) { + using namespace mlir; + 
SmallVector lower; + for (auto o : op.getOperands()) + lower.push_back(mapping.lookupOrDefault(o)); + return rewriter.create(op.getLoc(), types, op.getIntegerSet(), + lower, true); +} + +static inline mlir::scf::ForOp +cloneWithoutResults(mlir::scf::ForOp op, mlir::PatternRewriter &rewriter, + mlir::BlockAndValueMapping mapping = {}) { + using namespace mlir; + return rewriter.create( + op.getLoc(), mapping.lookupOrDefault(op.getLowerBound()), + mapping.lookupOrDefault(op.getUpperBound()), + mapping.lookupOrDefault(op.getStep())); +} +static inline mlir::AffineForOp +cloneWithoutResults(mlir::AffineForOp op, mlir::PatternRewriter &rewriter, + mlir::BlockAndValueMapping mapping = {}) { + using namespace mlir; + SmallVector lower; + for (auto o : op.getLowerBoundOperands()) + lower.push_back(mapping.lookupOrDefault(o)); + SmallVector upper; + for (auto o : op.getUpperBoundOperands()) + upper.push_back(mapping.lookupOrDefault(o)); + return rewriter.create(op.getLoc(), lower, op.getLowerBoundMap(), + upper, op.getUpperBoundMap(), + op.getStep()); +} + +static inline void clearBlock(mlir::Block *block, + mlir::PatternRewriter &rewriter) { + for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) { + assert(op.use_empty() && "expected 'op' to have no uses"); + rewriter.eraseOp(&op); + } +} + +static inline mlir::Block *getThenBlock(mlir::scf::IfOp op) { + return op.thenBlock(); +} +static inline mlir::Block *getThenBlock(mlir::AffineIfOp op) { + return op.getThenBlock(); +} +static inline mlir::Block *getElseBlock(mlir::scf::IfOp op) { + return op.elseBlock(); +} +static inline mlir::Block *getElseBlock(mlir::AffineIfOp op) { + if (op.hasElse()) + return op.getElseBlock(); + else + return nullptr; +} + +static inline mlir::Region &getThenRegion(mlir::scf::IfOp op) { + return op.getThenRegion(); +} +static inline mlir::Region &getThenRegion(mlir::AffineIfOp op) { + return op.getThenRegion(); +} +static inline mlir::Region &getElseRegion(mlir::scf::IfOp op) { + 
return op.getElseRegion(); +} +static inline mlir::Region &getElseRegion(mlir::AffineIfOp op) { + return op.getElseRegion(); +} + +static inline mlir::scf::YieldOp getThenYield(mlir::scf::IfOp op) { + return op.thenYield(); +} +static inline mlir::AffineYieldOp getThenYield(mlir::AffineIfOp op) { + return llvm::cast(op.getThenBlock()->getTerminator()); +} +static inline mlir::scf::YieldOp getElseYield(mlir::scf::IfOp op) { + return op.elseYield(); +} +static inline mlir::AffineYieldOp getElseYield(mlir::AffineIfOp op) { + return llvm::cast(op.getElseBlock()->getTerminator()); +} + +static inline bool inBound(mlir::scf::IfOp op, mlir::Value v) { + return op.getCondition() == v; +} +static inline bool inBound(mlir::AffineIfOp op, mlir::Value v) { + return llvm::any_of(op.getOperands(), [&](mlir::Value e) { return e == v; }); +} +static inline bool inBound(mlir::scf::ForOp op, mlir::Value v) { + return op.getUpperBound() == v; +} +static inline bool inBound(mlir::AffineForOp op, mlir::Value v) { + return llvm::any_of(op.getUpperBoundOperands(), + [&](mlir::Value e) { return e == v; }); +} +static inline bool hasElse(mlir::scf::IfOp op) { + return op.getElseRegion().getBlocks().size() > 0; +} +static inline bool hasElse(mlir::AffineIfOp op) { + return op.getElseRegion().getBlocks().size() > 0; +} diff --git a/lib/sql/Passes/CMakeLists.txt b/lib/sql/Passes/CMakeLists.txt new file mode 100644 index 000000000000..e2b4f46f658f --- /dev/null +++ b/lib/sql/Passes/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_dialect_library(MLIRSQLTransforms + SQLLower.cpp + + DEPENDS + MLIRPolygeistOpsIncGen + MLIRPolygeistPassIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRFuncDialect + MLIRFuncTransforms + MLIRIR + MLIRLLVMDialect + MLIRMathDialect + MLIRMemRefDialect + MLIRPass + ) \ No newline at end of file diff --git a/lib/sql/Passes/SQLLower.cpp b/lib/sql/Passes/SQLLower.cpp new file mode 100644 index 000000000000..6daaba66b6b1 --- /dev/null +++ b/lib/sql/Passes/SQLLower.cpp @@ -0,0 
+1,120 @@ +//===- SQLLower.cpp - Lower sql ops to mlir ------ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to lower gpu kernels in NVVM/gpu dialects into +// a generic SQL for representation +//===----------------------------------------------------------------------===// + +#include "PassDetails.h" +#include "mlir/Analysis/CallGraph.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include +#include + +#define DEBUG_TYPE "sql-opt" + +using namespace mlir; +using namespace mlir::arith; +using namespace mlir::func; +using namespace sql; + +namespace { +struct SQLLower : public SQLLowerBase { + void runOnOperation() override; +}; + +} // end anonymous namespace + +struct NumResultsOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(sql::NumResultsOp loop, + PatternRewriter &rewriter) const final { + auto module = loop->getParentOfType(); + + // 1) make sure the postgres_getresult function is declared + auto rowsfn = dyn_cast_or_null(symbolTable.lookupSymbolIn( + module, builder.getStringAttr("PQcmdTuples"))); + + auto atoifn = dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, builder.getStringAttr("atoi"))); + + // 2) convert the args to valid args to postgres_getresult abi + Value arg = loop.getHandle(); + arg = rewriter.create(loop.getLoc(), + rewriter.getIntTy(64), arg); + arg = 
rewriter.create( + loop.getLoc(), LLVM::LLVMPointerType::get(builder.getInt8Ty()), arg); + + // 3) call and replace + Value args[] = {arg} Value res = + rewriter.create(loop.getLoc(), rowsfn, args) + ->getResult(0); + + Value args2[] = {res} Value res2 = + rewriter.create(loop.getLoc(), atoifn, args2) + ->getResult(0); + + rewriter.replaceOpWithNewOp( + loop, rewriter.getIndexType(), res2); + + // 4) done + return success(); + } +}; + +void SQLLower::runOnOperation() { + auto module = getOperation(); + OpBuilder builder(module.getContext()); + builder.setInsertionPointToStart(module.getBody()); + + if (!dyn_cast_or_null(symbolTable.lookupSymbolIn( + module, builder.getStringAttr("PQcmdTuples")))) { + mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getInt8Ty())}; + mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getInt8Ty())}; + + auto fn = + builder.create(module.getLoc(), "PQcmdTuples", + builder.getFunctionType(argtys, rettys)); + SymbolTable::setSymbolVisibility(fn, SymbolTable::Private); + } + if (!dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, builder.getStringAttr("atoi")))) { + mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getInt8Ty())}; + + // todo use data layout + mlir::Type rettypes[] = {builder.getIntTy(sizeof(int))}; + + auto fn = builder.create( + module.getLoc(), "atoi", builder.getFunctionType(argtys, rettys)); + SymbolTable::setSymbolVisibility(fn, SymbolTable::Private); + } + + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); + + GreedyRewriteConfig config; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config); +} + +namespace mlir { +namespace polygeist { +std::unique_ptr createSQLLowerPass() { + return std::make_unique(); +} +} // namespace polygeist +} // namespace mlir From 32bc7d3b57b336251fb793ee217c4e6e0465c8a0 Mon Sep 17 00:00:00 2001 From: Carl Guo Date: Tue, 16 May 2023 20:32:46 -0400 Subject: [PATCH 07/15] sql lowering workings with 
pragma --- include/sql/Passes/Passes.h | 5 +- include/sql/Passes/Passes.td | 12 +++- lib/sql/Passes/CMakeLists.txt | 2 + lib/sql/Passes/PassDetails.h | 39 ++++++++++++ lib/sql/Passes/SQLLower.cpp | 50 +++++++++------ lib/sql/Passes/SQLRaising.cpp | 113 ++++++++++++++++++++++++++++++++++ test_with_pragma.c | 45 ++++++++++++++ 7 files changed, 244 insertions(+), 22 deletions(-) create mode 100644 lib/sql/Passes/PassDetails.h create mode 100644 lib/sql/Passes/SQLRaising.cpp create mode 100644 test_with_pragma.c diff --git a/include/sql/Passes/Passes.h b/include/sql/Passes/Passes.h index 3c0861eeab4e..82bc1667b8dd 100644 --- a/include/sql/Passes/Passes.h +++ b/include/sql/Passes/Passes.h @@ -10,10 +10,13 @@ class RewritePatternSet; class DominanceInfo; namespace sql { -std::unique_ptr createParallelLowerPass(); +std::unique_ptr createSQLLowerPass(); +std::unique_ptr createSQLRaisingPass(); } // namespace sql } // namespace mlir + + namespace mlir { // Forward declaration from Dialect.h template diff --git a/include/sql/Passes/Passes.td b/include/sql/Passes/Passes.td index 36faaab4a52f..11968e1415fe 100644 --- a/include/sql/Passes/Passes.td +++ b/include/sql/Passes/Passes.td @@ -4,11 +4,19 @@ include "mlir/Pass/PassBase.td" -def ParallelLower : Pass<"sql-lower", "mlir::ModuleOp"> { +def SQLLower : Pass<"sql-lower", "mlir::ModuleOp"> { let summary = "Lower sql op to mlir"; let dependentDialects = - ["arith::AirthDialect", "func::FuncDialect", "LLVM::LLVMDialect"]; + ["arith::ArithDialect", "func::FuncDialect", "LLVM::LLVMDialect"]; let constructor = "mlir::sql::createSQLLowerPass()"; } + +def SQLRaising : Pass<"sql-raising", "mlir::ModuleOp"> { + let summary = "Raise sql op to mlir"; + let dependentDialects = + ["arith::ArithDialect", "func::FuncDialect", "LLVM::LLVMDialect"]; + let constructor = "mlir::sql::createSQLRaisingPass()"; +} + #endif // SQL_PASSES diff --git a/lib/sql/Passes/CMakeLists.txt b/lib/sql/Passes/CMakeLists.txt index e2b4f46f658f..ee492cf8f26c 100644 
--- a/lib/sql/Passes/CMakeLists.txt +++ b/lib/sql/Passes/CMakeLists.txt @@ -1,9 +1,11 @@ add_mlir_dialect_library(MLIRSQLTransforms SQLLower.cpp + SQLRaising.cpp DEPENDS MLIRPolygeistOpsIncGen MLIRPolygeistPassIncGen + MLIRSQLPassIncGen LINK_LIBS PUBLIC MLIRArithDialect diff --git a/lib/sql/Passes/PassDetails.h b/lib/sql/Passes/PassDetails.h new file mode 100644 index 000000000000..1436f9e69809 --- /dev/null +++ b/lib/sql/Passes/PassDetails.h @@ -0,0 +1,39 @@ +//===- PassDetails.h - polygeist pass class details ----------------*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Stuff shared between the different polygeist passes. +// +//===----------------------------------------------------------------------===// + +// clang-tidy seems to expect the absolute path in the header guard on some +// systems, so just disable it. 
+// NOLINTNEXTLINE(llvm-header-guard) +#ifndef DIALECT_POLYGEIST_TRANSFORMS_PASSDETAILS_H +#define DIALECT_POLYGEIST_TRANSFORMS_PASSDETAILS_H + +#include "mlir/Pass/Pass.h" +#include "sql/SQLOps.h" +#include "sql/Passes/Passes.h" + +namespace mlir { +class FunctionOpInterface; +// Forward declaration from Dialect.h +template +void registerDialect(DialectRegistry &registry); +namespace sql { + +class SQLDialect; + +#define GEN_PASS_CLASSES +#include "sql/Passes/Passes.h.inc" + +} // namespace sql +} // namespace mlir + +#endif // DIALECT_POLYGEIST_TRANSFORMS_PASSDETAILS_H diff --git a/lib/sql/Passes/SQLLower.cpp b/lib/sql/Passes/SQLLower.cpp index 6daaba66b6b1..3fe7c831398f 100644 --- a/lib/sql/Passes/SQLLower.cpp +++ b/lib/sql/Passes/SQLLower.cpp @@ -1,4 +1,4 @@ -//===- SQLLower.cpp - Lower sql ops to mlir ------ -*-===// +//===- SQLLower.cpp - Lower PostgreSQL to sql mlir ops ------ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -21,15 +21,17 @@ #include "mlir/Transforms/Passes.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" +#include "sql/SQLOps.h" +#include "sql/Passes/Passes.h" #include #include -#define DEBUG_TYPE "sql-opt" +#define DEBUG_TYPE "sql-lower-opt" using namespace mlir; using namespace mlir::arith; using namespace mlir::func; -using namespace sql; +using namespace mlir::sql; namespace { struct SQLLower : public SQLLowerBase { @@ -45,26 +47,33 @@ struct NumResultsOpLowering : public OpRewritePattern { PatternRewriter &rewriter) const final { auto module = loop->getParentOfType(); + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(loop); + // 1) make sure the postgres_getresult function is declared - auto rowsfn = dyn_cast_or_null(symbolTable.lookupSymbolIn( - module, builder.getStringAttr("PQcmdTuples"))); + auto rowsfn = dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQcmdTuples"))); auto atoifn = dyn_cast_or_null( - symbolTable.lookupSymbolIn(module, builder.getStringAttr("atoi"))); + symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("atoi"))); // 2) convert the args to valid args to postgres_getresult abi Value arg = loop.getHandle(); arg = rewriter.create(loop.getLoc(), - rewriter.getIntTy(64), arg); + rewriter.getI64Type(), arg); arg = rewriter.create( - loop.getLoc(), LLVM::LLVMPointerType::get(builder.getInt8Ty()), arg); + loop.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), arg); // 3) call and replace - Value args[] = {arg} Value res = + Value args[] = {arg}; + + Value res = rewriter.create(loop.getLoc(), rowsfn, args) ->getResult(0); - Value args2[] = {res} Value res2 = + Value args2[] = {res}; + + Value res2 = rewriter.create(loop.getLoc(), atoifn, args2) ->getResult(0); @@ -78,29 +87,32 @@ struct NumResultsOpLowering : public OpRewritePattern { void SQLLower::runOnOperation() { auto module = getOperation(); + + SymbolTableCollection symbolTable; + 
symbolTable.getSymbolTable(module); OpBuilder builder(module.getContext()); builder.setInsertionPointToStart(module.getBody()); if (!dyn_cast_or_null(symbolTable.lookupSymbolIn( module, builder.getStringAttr("PQcmdTuples")))) { - mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getInt8Ty())}; - mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getInt8Ty())}; + mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI8Type())}; + mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI8Type())}; auto fn = builder.create(module.getLoc(), "PQcmdTuples", - builder.getFunctionType(argtys, rettys)); - SymbolTable::setSymbolVisibility(fn, SymbolTable::Private); + builder.getFunctionType(argtypes, rettypes)); + SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); } if (!dyn_cast_or_null( symbolTable.lookupSymbolIn(module, builder.getStringAttr("atoi")))) { - mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getInt8Ty())}; + mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI8Type())}; // todo use data layout - mlir::Type rettypes[] = {builder.getIntTy(sizeof(int))}; + mlir::Type rettypes[] = {builder.getI64Type()}; auto fn = builder.create( - module.getLoc(), "atoi", builder.getFunctionType(argtys, rettys)); - SymbolTable::setSymbolVisibility(fn, SymbolTable::Private); + module.getLoc(), "atoi", builder.getFunctionType(argtypes, rettypes)); + SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); } RewritePatternSet patterns(&getContext()); @@ -112,7 +124,7 @@ void SQLLower::runOnOperation() { } namespace mlir { -namespace polygeist { +namespace sql { std::unique_ptr createSQLLowerPass() { return std::make_unique(); } diff --git a/lib/sql/Passes/SQLRaising.cpp b/lib/sql/Passes/SQLRaising.cpp new file mode 100644 index 000000000000..c46270072010 --- /dev/null +++ b/lib/sql/Passes/SQLRaising.cpp @@ -0,0 +1,113 @@ +//===- SQLLower.cpp - Lower PostgreSQL to sql mlir ops ------ 
-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to raise calls to PostgreSQL client API +// functions (e.g. PQcmdTuples) into the sql dialect +//===----------------------------------------------------------------------===// + +#include "PassDetails.h" +#include "mlir/Analysis/CallGraph.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "sql/SQLOps.h" +#include "sql/Passes/Passes.h" +#include +#include + +#define DEBUG_TYPE "sql-raising-opt" + +using namespace mlir; +using namespace mlir::arith; +using namespace mlir::func; +using namespace mlir::sql; + +namespace { +struct SQLRaising : public SQLRaisingBase { + void runOnOperation() override; +}; + +} // end anonymous namespace + +struct PQcmdTuplesRaising : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(func::CallOp call, + PatternRewriter &rewriter) const final { + if (call.getCallee() != "PQcmdTuples") { + return failure(); + } + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(call); + auto module = call->getParentOfType(); + + // 2) convert the args to valid args to postgres_getresult abi + Value arg = call.getArgOperands()[0]; + arg = rewriter.create( + call.getLoc(), rewriter.getIntegerType(64), arg); + + arg = rewriter.create(call.getLoc(), + rewriter.getIndexType(), arg); + + Value res = rewriter.create(call.getLoc(), 
rewriter.getIndexType(), arg); + + res = rewriter.create(call.getLoc(), + rewriter.getI64Type(), res); + + auto itoafn = dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("itoa"))); + + Value args2[] = {res}; + + rewriter.replaceOpWithNewOp(call, itoafn, args2); + + // 4) done + return success(); + } +}; + +void SQLRaising::runOnOperation() { + auto module = getOperation(); + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + OpBuilder builder(module.getContext()); + builder.setInsertionPointToStart(module.getBody()); + + if (!dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, builder.getStringAttr("itoa")))) { + mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI8Type())}; + + // todo use data layout + mlir::Type argtypes[] = {builder.getI64Type()}; + + auto fn = builder.create( + module.getLoc(), "itoa", builder.getFunctionType(argtypes, rettypes)); + SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); + } + + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); + + GreedyRewriteConfig config; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config); +} + +namespace mlir { +namespace sql { +std::unique_ptr createSQLRaisingPass() { + return std::make_unique(); +} +} // namespace sql +} // namespace mlir diff --git a/test_with_pragma.c b/test_with_pragma.c new file mode 100644 index 000000000000..2a852b730804 --- /dev/null +++ b/test_with_pragma.c @@ -0,0 +1,45 @@ +#include +#include +#include + +// PGresult *PQexec(PGconn*, const char* command); +// PQgetvalue +// %7 = call @PQexec(%2, %6) : (memref, memref) -> memref +#pragma lower_to(num_rows_fn, "sql.num_results") +int num_rows_fn(size_t);// char* + +void do_exit(PGconn *conn) { + + PQfinish(conn); + exit(1); +} + +int main() { + + PGconn *conn = PQconnectdb("user=janbodnar dbname=testdb"); + + if (PQstatus(conn) == CONNECTION_BAD) { + + fprintf(stderr, "Connection to 
database failed: %s\n", + PQerrorMessage(conn)); + do_exit(conn); + } + + PGresult *res = PQexec(conn, "SELECT VERSION()"); + + if (PQresultStatus(res) != PGRES_TUPLES_OK) { + + printf("No data retrieved\n"); + PQclear(res); + do_exit(conn); + } + + printf("%s\n", PQgetvalue(res, 0, 0)); + printf("%d\n", num_rows_fn((size_t)res)); + // res, 0, 0)); + + PQclear(res); + PQfinish(conn); + + return 0; +} From 6634ed0cf467b7aeba6bdceeddac6ed00db2f3cc Mon Sep 17 00:00:00 2001 From: Carl Guo Date: Tue, 6 Jun 2023 19:27:47 -0400 Subject: [PATCH 08/15] mlir gen works; lowering still buggy --- include/sql/SQLOps.td | 12 +-- lib/sql/Ops.cpp | 82 ++++++++++++++- lib/sql/Passes/SQLLower.cpp | 137 +++++++++++++++++++++++++- lib/sql/Passes/SQLRaising.cpp | 86 ++++++++++++++++ test_with_pragma.c | 9 +- tools/cgeist/CMakeLists.txt | 3 + tools/cgeist/Lib/clang-mlir.cc | 4 + tools/cgeist/driver.cc | 3 + tools/polygeist-opt/polygeist-opt.cpp | 2 + 9 files changed, 327 insertions(+), 11 deletions(-) diff --git a/include/sql/SQLOps.td b/include/sql/SQLOps.td index 1acbcfeb6374..002c8444849c 100644 --- a/include/sql/SQLOps.td +++ b/include/sql/SQLOps.td @@ -26,7 +26,7 @@ def SelectOp : SQL_Op<"select", [Pure]> { def ExecuteOp : SQL_Op<"execute", []> { let summary = "execute query"; - let arguments = (ins Index:$handle); + let arguments = (ins Index:$conn, Index:$command); let results = (outs Index:$result); let hasFolder = 0; @@ -40,17 +40,17 @@ def NumResultsOp : SQL_Op<"num_results", [Pure]> { let results = (outs Index:$result); let hasFolder = 0; - let hasCanonicalizer = 0; + let hasCanonicalizer = 1; } -def ResultOp : SQL_Op<"get_result", [Pure]> { - let summary = "get results of execution"; +def GetValueOp : SQL_Op<"get_value", [Pure]> { + let summary = "get value of execution"; - let arguments = (ins Index:$handle, StrAttr:$column, Index:$row); + let arguments = (ins Index:$handle, Index:$column, Index:$row); let results = (outs AnyType:$result); let hasFolder = 0; - let 
hasCanonicalizer = 0; + let hasCanonicalizer = 1; } #endif // SQL_OPS \ No newline at end of file diff --git a/lib/sql/Ops.cpp b/lib/sql/Ops.cpp index 3cedd57e62f9..332c0e61075a 100644 --- a/lib/sql/Ops.cpp +++ b/lib/sql/Ops.cpp @@ -37,4 +37,84 @@ using namespace mlir; using namespace sql; -using namespace mlir::arith; \ No newline at end of file +using namespace mlir::arith; + + +class GetValueOpTypeFix final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GetValueOp op, + PatternRewriter &rewriter) const override { + + bool changed = false; + + Value handle = op.getOperand(0); + + if (!handle.getType().isa()) { + handle = rewriter.create(op.getLoc(), + rewriter.getIndexType(), handle); + changed = true; + } + Value row = op.getOperand(1); + if (!row.getType().isa()) { + row = rewriter.create(op.getLoc(), + rewriter.getIndexType(), row); + changed = true; + } + Value column = op.getOperand(2); + if (!column.getType().isa()) { + column = rewriter.create(op.getLoc(), + rewriter.getIndexType(), column); + changed = true; + } + + if (!changed) return failure(); + + rewriter.replaceOpWithNewOp(op, op.getType(), handle, row, column); + + return success(changed); + } +}; + +void GetValueOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert(context); +} + + + +class NumResultsOpTypeFix final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(NumResultsOp op, + PatternRewriter &rewriter) const override { + bool changed = false; + Value handle = op->getOperand(0); + + if (handle.getType().isa() && op->getResultTypes()[0].isa()) + return failure(); + + if (!handle.getType().isa()) { + handle = rewriter.create(op.getLoc(), + rewriter.getIndexType(), handle); + changed = true; + } + + mlir::Value res = rewriter.create(op.getLoc(), rewriter.getIndexType(), handle); + + if 
(op->getResultTypes()[0].isa()) { + rewriter.replaceOp(op, res); + } else { + rewriter.replaceOpWithNewOp(op, op->getResultTypes()[0], res); + } + + return success(changed); + } +}; + +void NumResultsOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert(context); +} \ No newline at end of file diff --git a/lib/sql/Passes/SQLLower.cpp b/lib/sql/Passes/SQLLower.cpp index 3fe7c831398f..8637ce277a42 100644 --- a/lib/sql/Passes/SQLLower.cpp +++ b/lib/sql/Passes/SQLLower.cpp @@ -48,7 +48,7 @@ struct NumResultsOpLowering : public OpRewritePattern { auto module = loop->getParentOfType(); SymbolTableCollection symbolTable; - symbolTable.getSymbolTable(loop); + symbolTable.getSymbolTable(module); // 1) make sure the postgres_getresult function is declared auto rowsfn = dyn_cast_or_null( @@ -80,7 +80,110 @@ struct NumResultsOpLowering : public OpRewritePattern { rewriter.replaceOpWithNewOp( loop, rewriter.getIndexType(), res2); - // 4) done + return success(); + } +}; + + +struct GetValueOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(sql::GetValueOp loop, + PatternRewriter &rewriter) const final { + auto module = loop->getParentOfType(); + + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + + // 1) make sure the postgres_getresult function is declared + auto valuefn = dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQgetvalue"))); + + auto atoifn = dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("atoi"))); + + // 2) convert the args to valid args to postgres_getresult abi + Value handle = loop.getHandle(); + handle = rewriter.create(loop.getLoc(), + rewriter.getI64Type(), handle); + handle = rewriter.create( + loop.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), handle); + + Value row = loop.getRow(); + Value column = loop.getColumn(); + + + // 3) call and 
replace + Value args[] = {handle, row, column}; + + Value res = + rewriter.create(loop.getLoc(), valuefn, args) + ->getResult(0); + + Value args2[] = {res}; + + Value res2 = + rewriter.create(loop.getLoc(), atoifn, args2) + ->getResult(0); + + if (loop.getType() != res2.getType()) { + if (loop.getType().isa()) + res2 = rewriter.create(loop.getLoc(), + loop.getType(), res2); + else if (auto IT = loop.getType().dyn_cast()) { + auto IT2 = res2.getType().dyn_cast(); + if (IT.getWidth() < IT2.getWidth()) { + res2 = rewriter.create(loop.getLoc(), + loop.getType(), res2); + } else if (IT.getWidth() > IT2.getWidth()) { + res2 = rewriter.create(loop.getLoc(), + loop.getType(), res2); + } else assert(0 && "illegal integer type conversion"); + } else { + assert(0 && "illegal type conversion"); + } + } + rewriter.replaceOp(loop, res2); + + return success(); + } +}; + + + +struct ExecuteOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(sql::ExecuteOp loop, + PatternRewriter &rewriter) const final { + auto module = loop->getParentOfType(); + + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + + // 1) make sure the postgres_getresult function is declared + auto executefn = dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQexec"))); + + // 2) convert the args to valid args to postgres_getresult abi + Value conn = loop.getConn(); + conn = rewriter.create(loop.getLoc(), + rewriter.getI64Type(), conn); + conn = rewriter.create( + loop.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), conn); + + Value command = loop.getCommand(); + + // 3) call and replace + Value args[] = {conn, command}; + + Value res = + rewriter.create(loop.getLoc(), executefn, args) + ->getResult(0); + + rewriter.replaceOpWithNewOp( + loop, rewriter.getIndexType(), res); + return success(); } }; @@ -103,9 +206,35 @@ void SQLLower::runOnOperation() { builder.getFunctionType(argtypes, 
rettypes)); SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); } + + if (!dyn_cast_or_null(symbolTable.lookupSymbolIn( + module, builder.getStringAttr("PQgetvalue")))) { + mlir::Type argtypes[] = { + LLVM::LLVMPointerType::get(builder.getI8Type()), + LLVM::LLVMPointerType::get(builder.getI64Type()), + LLVM::LLVMPointerType::get(builder.getI64Type())}; + mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI8Type())}; + + auto fn = + builder.create(module.getLoc(), "PQgetvalue", + builder.getFunctionType(argtypes, rettypes)); + SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); + } + + // if (!dyn_cast_or_null( + // symbolTable.lookupSymbolIn(module, builder.getStringAttr("PQexec")))) { + // mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI32Type()), + // LLVM::LLVMPointerType::get(builder.getI8Type())}; + // mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI32Type())}; + + // auto fn = builder.create( + // module.getLoc(), "PQexec", builder.getFunctionType(argtypes, rettypes)); + // SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); + // } + if (!dyn_cast_or_null( symbolTable.lookupSymbolIn(module, builder.getStringAttr("atoi")))) { - mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI8Type())}; + mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type())}; // todo use data layout mlir::Type rettypes[] = {builder.getI64Type()}; @@ -117,6 +246,8 @@ void SQLLower::runOnOperation() { RewritePatternSet patterns(&getContext()); patterns.insert(&getContext()); + patterns.insert(&getContext()); + // patterns.insert(&getContext()); GreedyRewriteConfig config; (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), diff --git a/lib/sql/Passes/SQLRaising.cpp b/lib/sql/Passes/SQLRaising.cpp index c46270072010..f843cc120a34 100644 --- a/lib/sql/Passes/SQLRaising.cpp +++ b/lib/sql/Passes/SQLRaising.cpp @@ -77,6 +77,90 @@ struct 
PQcmdTuplesRaising : public OpRewritePattern { } }; + + +struct PQgetvalueRaising : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(func::CallOp call, + PatternRewriter &rewriter) const final { + if (call.getCallee() != "PQgetvalue") { + return failure(); + } + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(call); + auto module = call->getParentOfType(); + + // 2) convert the args to valid args to postgres_getresult abi + Value handle = call.getArgOperands()[0]; + handle = rewriter.create( + call.getLoc(), rewriter.getIntegerType(64), handle); + + handle = rewriter.create(call.getLoc(), + rewriter.getIndexType(), handle); + + Value row = call.getArgOperands()[1]; + Value column = call.getArgOperands()[2]; + + Value res = rewriter.create(call.getLoc(), rewriter.getIndexType(), handle, row, column); + // or Value res = rewriter.create(call.getLoc(), rewriter.getIndexType(), {handle, row, column}); + + res = rewriter.create(call.getLoc(), + rewriter.getI64Type(), res); + + auto itoafn = dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("itoa"))); + + Value args2[] = {res}; + + rewriter.replaceOpWithNewOp(call, itoafn, args2); + + // 4) done + return success(); + } +}; + + +struct PQexecRaising : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(func::CallOp call, + PatternRewriter &rewriter) const final { + if (call.getCallee() != "PQexec") { + return failure(); + } + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(call); + auto module = call->getParentOfType(); + + // 2) convert the args to valid args to postgres_getresult abi + Value conn = call.getArgOperands()[0]; + conn = rewriter.create( + call.getLoc(), rewriter.getIntegerType(64), conn); + + conn = rewriter.create(call.getLoc(), + rewriter.getIndexType(), conn); + + Value command = call.getArgOperands()[1]; + command = rewriter.create( + 
call.getLoc(), rewriter.getIntegerType(64), command); + + command = rewriter.create(call.getLoc(), + rewriter.getIndexType(), command); + + Value res = rewriter.create(call.getLoc(), rewriter.getIndexType(), conn, command); + + res = rewriter.create(call.getLoc(), + rewriter.getI64Type(), res); + + rewriter.replaceOp(call, res); + /// rewriter.replaceOpWithNewOp(call, itoafn, res); + + // 4) done + return success(); + } +}; + void SQLRaising::runOnOperation() { auto module = getOperation(); SymbolTableCollection symbolTable; @@ -98,6 +182,8 @@ void SQLRaising::runOnOperation() { RewritePatternSet patterns(&getContext()); patterns.insert(&getContext()); + patterns.insert(&getContext()); + // patterns.insert(&getContext()); GreedyRewriteConfig config; (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), diff --git a/test_with_pragma.c b/test_with_pragma.c index 2a852b730804..d12558a64839 100644 --- a/test_with_pragma.c +++ b/test_with_pragma.c @@ -8,8 +8,14 @@ #pragma lower_to(num_rows_fn, "sql.num_results") int num_rows_fn(size_t);// char* +#pragma lower_to(get_value_fn_int, "sql.get_value") +int get_value_fn_int(size_t, int, int); + +#pragma lower_to(get_value_fn_double, "sql.get_value") +double get_value_fn_double(size_t, int, int); + + void do_exit(PGconn *conn) { - PQfinish(conn); exit(1); } @@ -35,6 +41,7 @@ int main() { } printf("%s\n", PQgetvalue(res, 0, 0)); + printf("%d\n", get_value_fn_int((size_t)res, 0, 0)); printf("%d\n", num_rows_fn((size_t)res)); // res, 0, 0)); diff --git a/tools/cgeist/CMakeLists.txt b/tools/cgeist/CMakeLists.txt index 1b0e7434c773..30dd13d116b7 100644 --- a/tools/cgeist/CMakeLists.txt +++ b/tools/cgeist/CMakeLists.txt @@ -59,6 +59,7 @@ target_compile_definitions(cgeist PUBLIC -DLLVM_OBJ_ROOT="${LLVM_BINARY_DIR}") target_link_libraries(cgeist PRIVATE MLIRSCFTransforms MLIRPolygeist + MLIRSQL MLIRSupport MLIRIR @@ -76,6 +77,7 @@ target_link_libraries(cgeist PRIVATE MLIRMathToLLVM MLIRTargetLLVMIRImport 
MLIRPolygeistTransforms + MLIRSQLTransforms MLIRLLVMToLLVMIRTranslation MLIRSCFToOpenMP MLIROpenMPToLLVM @@ -109,4 +111,5 @@ target_link_libraries(cgeist PRIVATE clangSerialization ) add_dependencies(cgeist MLIRPolygeistOpsIncGen MLIRPolygeistPassIncGen) +add_dependencies(cgeist MLIRSQLOpsIncGen MLIRSQLPassIncGen) add_subdirectory(Test) diff --git a/tools/cgeist/Lib/clang-mlir.cc b/tools/cgeist/Lib/clang-mlir.cc index 9af4714326da..34dfad6dfa6d 100644 --- a/tools/cgeist/Lib/clang-mlir.cc +++ b/tools/cgeist/Lib/clang-mlir.cc @@ -5606,6 +5606,10 @@ mlir::Type MLIRASTConsumer::getMLIRType(clang::QualType qt, bool *implicitRef, return retTy; } + if (ST->isOpaque()) { + OpBuilder builder(module->getContext()); + types.push_back(builder.getIntegerType(8)); + } if (!types.size()) { RT->dump(); llvm::errs() << "ST: " << *ST << "\n"; diff --git a/tools/cgeist/driver.cc b/tools/cgeist/driver.cc index 4fdd907ddff1..e58c13fcddcc 100644 --- a/tools/cgeist/driver.cc +++ b/tools/cgeist/driver.cc @@ -73,7 +73,9 @@ #include #include "polygeist/Dialect.h" +#include "sql/SQLDialect.h" #include "polygeist/Passes/Passes.h" +#include "sql/Passes/Passes.h" #include @@ -550,6 +552,7 @@ int main(int argc, char **argv) { context.getOrLoadDialect(); context.getOrLoadDialect(); context.getOrLoadDialect(); + context.getOrLoadDialect(); LLVM::LLVMFunctionType::attachInterface(context); LLVM::LLVMPointerType::attachInterface(context); diff --git a/tools/polygeist-opt/polygeist-opt.cpp b/tools/polygeist-opt/polygeist-opt.cpp index b642cc5aa076..60d1241c0d3e 100644 --- a/tools/polygeist-opt/polygeist-opt.cpp +++ b/tools/polygeist-opt/polygeist-opt.cpp @@ -35,6 +35,7 @@ #include "sql/SQLDialect.h" #include "sql/SQLOps.h" +#include "sql/Passes/Passes.h" using namespace mlir; @@ -68,6 +69,7 @@ int main(int argc, char **argv) { registry.insert(); mlir::registerpolygeistPasses(); mlir::func::registerInlinerExtension(registry); + mlir::registersqlPasses(); // Register the standard passes we want. 
mlir::registerCSEPass(); From 548d94a9546194229e053d8077f1a37a03b6965e Mon Sep 17 00:00:00 2001 From: Carl Guo Date: Sat, 17 Jun 2023 18:34:00 -0400 Subject: [PATCH 09/15] lowering still buggy --- lib/sql/Ops.cpp | 8 +- lib/sql/Passes/SQLLower.cpp | 190 +++++++++++++++++++++++------------- 2 files changed, 127 insertions(+), 71 deletions(-) diff --git a/lib/sql/Ops.cpp b/lib/sql/Ops.cpp index 332c0e61075a..77cb65b5ff61 100644 --- a/lib/sql/Ops.cpp +++ b/lib/sql/Ops.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/SideEffectInterfaces.h" + #include "sql/SQLDialect.h" #include "sql/SQLOps.h" @@ -29,6 +30,7 @@ #include "mlir/IR/Dominance.h" #include "mlir/IR/IntegerSet.h" #include "mlir/Transforms/SideEffectUtils.h" +// #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" @@ -94,7 +96,7 @@ class NumResultsOpTypeFix final : public OpRewritePattern { Value handle = op->getOperand(0); if (handle.getType().isa() && op->getResultTypes()[0].isa()) - return failure(); + return failure(); if (!handle.getType().isa()) { handle = rewriter.create(op.getLoc(), @@ -105,9 +107,9 @@ class NumResultsOpTypeFix final : public OpRewritePattern { mlir::Value res = rewriter.create(op.getLoc(), rewriter.getIndexType(), handle); if (op->getResultTypes()[0].isa()) { - rewriter.replaceOp(op, res); + rewriter.replaceOp(op, res); } else { - rewriter.replaceOpWithNewOp(op, op->getResultTypes()[0], res); + rewriter.replaceOpWithNewOp(op, op->getResultTypes()[0], res); } return success(changed); diff --git a/lib/sql/Passes/SQLLower.cpp b/lib/sql/Passes/SQLLower.cpp index 8637ce277a42..77c010def2f3 100644 --- a/lib/sql/Passes/SQLLower.cpp +++ b/lib/sql/Passes/SQLLower.cpp @@ -40,45 +40,98 @@ struct SQLLower : public SQLLowerBase { } // end anonymous namespace + +struct ExecuteOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + 
LogicalResult matchAndRewrite(sql::ExecuteOp op, + PatternRewriter &rewriter) const final { + auto module = op->getParentOfType(); + + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + + // 1) make sure the postgres_getresult function is declared + auto execfn = dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQexec"))); + + auto atoifn = dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("atoi"))); + + // 2) convert the args to valid args to postgres_getresult abi + Value conn = op.getConn(); + conn = rewriter.create(op.getLoc(), + rewriter.getI64Type(), conn); + conn = rewriter.create( + op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), conn); + + Value command = op.getCommand(); + command = rewriter.create(op.getLoc(), + rewriter.getI64Type(), command); + command = rewriter.create( + op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), command); + + // 3) call and replace + Value args[] = {conn, command}; + + Value res = + rewriter.create(op.getLoc(), execfn, args) + ->getResult(0); + + Value args2[] = {res}; + + Value res2 = + rewriter.create(op.getLoc(), atoifn, args2) + ->getResult(0); + + rewriter.replaceOpWithNewOp( + op, rewriter.getIndexType(), res2); + + return success(); + } +}; + + struct NumResultsOpLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(sql::NumResultsOp loop, + LogicalResult matchAndRewrite(sql::NumResultsOp op, PatternRewriter &rewriter) const final { - auto module = loop->getParentOfType(); + auto module = op->getParentOfType(); SymbolTableCollection symbolTable; symbolTable.getSymbolTable(module); // 1) make sure the postgres_getresult function is declared auto rowsfn = dyn_cast_or_null( - symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQcmdTuples"))); + symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQntuples"))); auto atoifn = dyn_cast_or_null( 
symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("atoi"))); // 2) convert the args to valid args to postgres_getresult abi - Value arg = loop.getHandle(); - arg = rewriter.create(loop.getLoc(), + Value arg = op.getHandle(); + arg = rewriter.create(op.getLoc(), rewriter.getI64Type(), arg); + arg = rewriter.create( - loop.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), arg); + op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), arg); // 3) call and replace Value args[] = {arg}; Value res = - rewriter.create(loop.getLoc(), rowsfn, args) + rewriter.create(op.getLoc(), rowsfn, args) ->getResult(0); Value args2[] = {res}; Value res2 = - rewriter.create(loop.getLoc(), atoifn, args2) + rewriter.create(op.getLoc(), atoifn, args2) ->getResult(0); rewriter.replaceOpWithNewOp( - loop, rewriter.getIndexType(), res2); + op, rewriter.getIndexType(), res2); return success(); } @@ -88,9 +141,9 @@ struct NumResultsOpLowering : public OpRewritePattern { struct GetValueOpLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(sql::GetValueOp loop, + LogicalResult matchAndRewrite(sql::GetValueOp op, PatternRewriter &rewriter) const final { - auto module = loop->getParentOfType(); + auto module = op->getParentOfType(); SymbolTableCollection symbolTable; symbolTable.getSymbolTable(module); @@ -103,47 +156,47 @@ struct GetValueOpLowering : public OpRewritePattern { symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("atoi"))); // 2) convert the args to valid args to postgres_getresult abi - Value handle = loop.getHandle(); - handle = rewriter.create(loop.getLoc(), + Value handle = op.getHandle(); + handle = rewriter.create(op.getLoc(), rewriter.getI64Type(), handle); handle = rewriter.create( - loop.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), handle); + op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), handle); - Value row = loop.getRow(); - Value column = 
loop.getColumn(); + Value row = op.getRow(); + Value column = op.getColumn(); // 3) call and replace Value args[] = {handle, row, column}; Value res = - rewriter.create(loop.getLoc(), valuefn, args) + rewriter.create(op.getLoc(), valuefn, args) ->getResult(0); Value args2[] = {res}; Value res2 = - rewriter.create(loop.getLoc(), atoifn, args2) + rewriter.create(op.getLoc(), atoifn, args2) ->getResult(0); - if (loop.getType() != res2.getType()) { - if (loop.getType().isa()) - res2 = rewriter.create(loop.getLoc(), - loop.getType(), res2); - else if (auto IT = loop.getType().dyn_cast()) { + if (op.getType() != res2.getType()) { + if (op.getType().isa()) + res2 = rewriter.create(op.getLoc(), + op.getType(), res2); + else if (auto IT = op.getType().dyn_cast()) { auto IT2 = res2.getType().dyn_cast(); if (IT.getWidth() < IT2.getWidth()) { - res2 = rewriter.create(loop.getLoc(), - loop.getType(), res2); + res2 = rewriter.create(op.getLoc(), + op.getType(), res2); } else if (IT.getWidth() > IT2.getWidth()) { - res2 = rewriter.create(loop.getLoc(), - loop.getType(), res2); + res2 = rewriter.create(op.getLoc(), + op.getType(), res2); } else assert(0 && "illegal integer type conversion"); } else { assert(0 && "illegal type conversion"); } } - rewriter.replaceOp(loop, res2); + rewriter.replaceOp(op, res2); return success(); } @@ -151,42 +204,42 @@ struct GetValueOpLowering : public OpRewritePattern { -struct ExecuteOpLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +// struct ExecuteOpLowering : public OpRewritePattern { +// using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(sql::ExecuteOp loop, - PatternRewriter &rewriter) const final { - auto module = loop->getParentOfType(); +// LogicalResult matchAndRewrite(sql::ExecuteOp op, +// PatternRewriter &rewriter) const final { +// auto module = op->getParentOfType(); - SymbolTableCollection symbolTable; - symbolTable.getSymbolTable(module); +// SymbolTableCollection 
symbolTable; +// symbolTable.getSymbolTable(module); - // 1) make sure the postgres_getresult function is declared - auto executefn = dyn_cast_or_null( - symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQexec"))); +// // 1) make sure the postgres_getresult function is declared +// auto executefn = dyn_cast_or_null( +// symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQexec"))); - // 2) convert the args to valid args to postgres_getresult abi - Value conn = loop.getConn(); - conn = rewriter.create(loop.getLoc(), - rewriter.getI64Type(), conn); - conn = rewriter.create( - loop.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), conn); +// // 2) convert the args to valid args to postgres_getresult abi +// Value conn = op.getConn(); +// conn = rewriter.create(op.getLoc(), +// rewriter.getI64Type(), conn); +// conn = rewriter.create( +// op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), conn); - Value command = loop.getCommand(); +// Value command = op.getCommand(); - // 3) call and replace - Value args[] = {conn, command}; +// // 3) call and replace +// Value args[] = {conn, command}; - Value res = - rewriter.create(loop.getLoc(), executefn, args) - ->getResult(0); +// Value res = +// rewriter.create(op.getLoc(), executefn, args) +// ->getResult(0); - rewriter.replaceOpWithNewOp( - loop, rewriter.getIndexType(), res); +// rewriter.replaceOpWithNewOp( +// op, rewriter.getIndexType(), res); - return success(); - } -}; +// return success(); +// } +// }; void SQLLower::runOnOperation() { auto module = getOperation(); @@ -197,12 +250,12 @@ void SQLLower::runOnOperation() { builder.setInsertionPointToStart(module.getBody()); if (!dyn_cast_or_null(symbolTable.lookupSymbolIn( - module, builder.getStringAttr("PQcmdTuples")))) { + module, builder.getStringAttr("PQntuples")))) { mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI8Type())}; - mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI8Type())}; + 
mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI64Type())}; auto fn = - builder.create(module.getLoc(), "PQcmdTuples", + builder.create(module.getLoc(), "PQntuples", builder.getFunctionType(argtypes, rettypes)); SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); } @@ -221,20 +274,21 @@ void SQLLower::runOnOperation() { SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); } - // if (!dyn_cast_or_null( - // symbolTable.lookupSymbolIn(module, builder.getStringAttr("PQexec")))) { - // mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI32Type()), - // LLVM::LLVMPointerType::get(builder.getI8Type())}; - // mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI32Type())}; +// if (!dyn_cast_or_null( +// symbolTable.lookupSymbolIn(module, builder.getStringAttr("PQexec")))) { +// mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI64Type()), +// LLVM::LLVMPointerType::get(builder.getI8Type())}; +// mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI64Type())}; - // auto fn = builder.create( - // module.getLoc(), "PQexec", builder.getFunctionType(argtypes, rettypes)); - // SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); - // } +// auto fn = builder.create( +// module.getLoc(), "PQexec", builder.getFunctionType(argtypes, rettypes)); +// SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); +// } if (!dyn_cast_or_null( symbolTable.lookupSymbolIn(module, builder.getStringAttr("atoi")))) { - mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type())}; + // mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type())}; + mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI64Type())}; // todo use data layout mlir::Type rettypes[] = {builder.getI64Type()}; From 5bed072cd7f159201d9e45579ccfcaadc81b7566 Mon Sep 17 00:00:00 2001 From: Carl Guo Date: Mon, 3 Jul 2023 15:58:23 -0400 Subject: [PATCH 10/15] select op 
--- include/sql/SQLOps.td | 65 ++++++++++-- include/sql/SQLTypes.h | 23 +++++ include/sql/SQLTypes.td | 20 ++++ lib/sql/Ops.cpp | 57 ++++++++++- lib/sql/Passes/SQLLower.cpp | 180 +++++++++++++--------------------- lib/sql/Passes/SQLRaising.cpp | 127 +++++++++++------------- 6 files changed, 286 insertions(+), 186 deletions(-) create mode 100644 include/sql/SQLTypes.h create mode 100644 include/sql/SQLTypes.td diff --git a/include/sql/SQLOps.td b/include/sql/SQLOps.td index 002c8444849c..db96c52a293b 100644 --- a/include/sql/SQLOps.td +++ b/include/sql/SQLOps.td @@ -12,27 +12,67 @@ include "mlir/IR/AttrTypeBase.td" include "SQLDialect.td" +def IntOp : SQL_Op<"int", [Pure]> { + let summary = "select"; + + let arguments = (ins StrAttr:$expr); + let results = (outs SQLExprType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +def ColumnOp : SQL_Op<"column", [Pure]> { + let summary = "select"; + + let arguments = (ins StrAttr:$expr); + let results = (outs SQLExprType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +def TableOp : SQL_Op<"table", [Pure]> { + let summary = "select"; + + let arguments = (ins StrAttr:$expr); + let results = (outs SQLExprType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + def SelectOp : SQL_Op<"select", [Pure]> { let summary = "select"; - // TODO: limit (optional), where clauses, join, etc - let arguments = (ins StrArrayAttr:$column, StrAttr:$table); - let results = (outs Index : $result); + let arguments = (ins Variadic:$columns, Optional:$table); + let results = (outs SQLExprType:$result); let hasFolder = 0; let hasCanonicalizer = 0; } +def UnparsedOp : SQL_Op<"unparsed", [Pure]> { + let summary = "unparsed sql op"; + + let arguments = (ins AnyType:$input); + let results = (outs SQLExprType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 1; +} + def ExecuteOp : SQL_Op<"execute", []> { let summary = "execute query"; - let arguments = (ins Index:$conn, Index:$command); + let 
arguments = (ins Index:$conn, SQLExprType:$command); let results = (outs Index:$result); let hasFolder = 0; - let hasCanonicalizer = 0; + let hasCanonicalizer = 1; } + def NumResultsOp : SQL_Op<"num_results", [Pure]> { let summary = "number of results"; @@ -52,5 +92,18 @@ def GetValueOp : SQL_Op<"get_value", [Pure]> { let hasFolder = 0; let hasCanonicalizer = 1; } - + +// def SelectOp : SQL_Op<"select", [Pure]>{ +// let summary = "select"; + +// let arguments = (ins StrArrayAttr:$columns, +// Optional:$from); +// // optional>:$where, +// // optional:$limit, +// // optional:$order); +// let results = (outs AnyType:$result); + +// let hasFolder = 0; +// let hasCanonicalizer = 0; +// } #endif // SQL_OPS \ No newline at end of file diff --git a/include/sql/SQLTypes.h b/include/sql/SQLTypes.h new file mode 100644 index 000000000000..501b697df78f --- /dev/null +++ b/include/sql/SQLTypes.h @@ -0,0 +1,23 @@ +//===- SQLOps.h - SQL dialect ops --------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef SQLTYPES_H +#define SQLTYPES_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ViewLikeInterface.h" +#include "llvm/Support/CommandLine.h" + +#define GET_TYPE_CLASSES +#include "sql/SQLTypes.h.inc" + +#endif \ No newline at end of file diff --git a/include/sql/SQLTypes.td b/include/sql/SQLTypes.td new file mode 100644 index 000000000000..a98fe7d3def8 --- /dev/null +++ b/include/sql/SQLTypes.td @@ -0,0 +1,20 @@ +//===- SQLTypes.td - SQL dialect types ----------------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef SQL_TYPES +#define SQL_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "SQLDialect.td" + +def SQLExprType : SQL_Type<"Expr", "expr"> { + let summary = "SQL expression type"; +} + + +#endif // SQL_TYPES \ No newline at end of file diff --git a/lib/sql/Ops.cpp b/lib/sql/Ops.cpp index 77cb65b5ff61..45bc8264f22e 100644 --- a/lib/sql/Ops.cpp +++ b/lib/sql/Ops.cpp @@ -16,6 +16,7 @@ #include "sql/SQLDialect.h" #include "sql/SQLOps.h" +#include "polygeist/Ops.h" #define GET_OP_CLASSES #include "sql/SQLOps.cpp.inc" @@ -30,7 +31,6 @@ #include "mlir/IR/Dominance.h" #include "mlir/IR/IntegerSet.h" #include "mlir/Transforms/SideEffectUtils.h" -// #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" @@ -52,7 +52,6 @@ class GetValueOpTypeFix final : public OpRewritePattern { bool changed = false; Value handle = 
op.getOperand(0); - if (!handle.getType().isa()) { handle = rewriter.create(op.getLoc(), rewriter.getIndexType(), handle); @@ -119,4 +118,58 @@ class NumResultsOpTypeFix final : public OpRewritePattern { void NumResultsOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.insert(context); +} + + + +class ExecuteOpTypeFix final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExecuteOp op, + PatternRewriter &rewriter) const override { + bool changed = false; + + Value conn = op->getOperand(0); + Value command = op->getOperand(1); + + if (conn.getType().isa() && command.getType().isa() && op->getResultTypes()[0].isa()) + return failure(); + + if (!conn.getType().isa()) { + conn = rewriter.create(op.getLoc(), + rewriter.getIndexType(), conn); + changed = true; + } + if (command.getType().isa()) { + command = rewriter.create(op.getLoc(), + LLVM::LLVMPointerType::get(rewriter.getI8Type()), command); + changed = true; + } + if (command.getType().isa()) { + command = rewriter.create(op.getLoc(), + rewriter.getI64Type(), command); + changed = true; + } + if (!command.getType().isa()) { + command = rewriter.create(op.getLoc(), + rewriter.getIndexType(), command); + changed = true; + } + + if (!changed) return failure(); + mlir::Value res = rewriter.create(op.getLoc(), rewriter.getIndexType(), conn, command); + rewriter.replaceOp(op, res); + // if (op->getResultTypes()[0].isa()) { + // rewriter.replaceOp(op, res); + // } else { + // rewriter.replaceOpWithNewOp(op, op->getResultTypes()[0], res); + // } + return success(changed); + } +}; + +void ExecuteOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert(context); } \ No newline at end of file diff --git a/lib/sql/Passes/SQLLower.cpp b/lib/sql/Passes/SQLLower.cpp index 77c010def2f3..9dc30f9f75ba 100644 --- a/lib/sql/Passes/SQLLower.cpp +++ b/lib/sql/Passes/SQLLower.cpp @@ 
-11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "PassDetails.h" +#include "polygeist/Ops.h" #include "mlir/Analysis/CallGraph.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Async/IR/Async.h" @@ -41,57 +42,6 @@ struct SQLLower : public SQLLowerBase { } // end anonymous namespace -struct ExecuteOpLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(sql::ExecuteOp op, - PatternRewriter &rewriter) const final { - auto module = op->getParentOfType(); - - SymbolTableCollection symbolTable; - symbolTable.getSymbolTable(module); - - // 1) make sure the postgres_getresult function is declared - auto execfn = dyn_cast_or_null( - symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQexec"))); - - auto atoifn = dyn_cast_or_null( - symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("atoi"))); - - // 2) convert the args to valid args to postgres_getresult abi - Value conn = op.getConn(); - conn = rewriter.create(op.getLoc(), - rewriter.getI64Type(), conn); - conn = rewriter.create( - op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), conn); - - Value command = op.getCommand(); - command = rewriter.create(op.getLoc(), - rewriter.getI64Type(), command); - command = rewriter.create( - op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), command); - - // 3) call and replace - Value args[] = {conn, command}; - - Value res = - rewriter.create(op.getLoc(), execfn, args) - ->getResult(0); - - Value args2[] = {res}; - - Value res2 = - rewriter.create(op.getLoc(), atoifn, args2) - ->getResult(0); - - rewriter.replaceOpWithNewOp( - op, rewriter.getIndexType(), res2); - - return success(); - } -}; - - struct NumResultsOpLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -106,16 +56,16 @@ struct NumResultsOpLowering : public OpRewritePattern { auto rowsfn = dyn_cast_or_null( 
symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQntuples"))); - auto atoifn = dyn_cast_or_null( - symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("atoi"))); - // 2) convert the args to valid args to postgres_getresult abi Value arg = op.getHandle(); arg = rewriter.create(op.getLoc(), - rewriter.getI64Type(), arg); + rewriter.getI8Type(), arg); arg = rewriter.create( op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), arg); + + arg = rewriter.create(op.getLoc(), + rowsfn.getFunctionType().getInput(0), arg); // 3) call and replace Value args[] = {arg}; @@ -124,14 +74,8 @@ struct NumResultsOpLowering : public OpRewritePattern { rewriter.create(op.getLoc(), rowsfn, args) ->getResult(0); - Value args2[] = {res}; - - Value res2 = - rewriter.create(op.getLoc(), atoifn, args2) - ->getResult(0); - rewriter.replaceOpWithNewOp( - op, rewriter.getIndexType(), res2); + op, rewriter.getIndexType(), res); return success(); } @@ -162,11 +106,16 @@ struct GetValueOpLowering : public OpRewritePattern { handle = rewriter.create( op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), handle); + handle = rewriter.create(op.getLoc(), + valuefn.getFunctionType().getInput(0), handle); + Value row = op.getRow(); + row = rewriter.create(op.getLoc(), + valuefn.getFunctionType().getInput(1), row); Value column = op.getColumn(); + column = rewriter.create(op.getLoc(), + valuefn.getFunctionType().getInput(2), column); - - // 3) call and replace Value args[] = {handle, row, column}; Value res = @@ -202,44 +151,55 @@ struct GetValueOpLowering : public OpRewritePattern { } }; +struct ExecuteOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(sql::ExecuteOp op, + PatternRewriter &rewriter) const final { + auto module = op->getParentOfType(); -// struct ExecuteOpLowering : public OpRewritePattern { -// using OpRewritePattern::OpRewritePattern; + SymbolTableCollection symbolTable; + 
symbolTable.getSymbolTable(module); -// LogicalResult matchAndRewrite(sql::ExecuteOp op, -// PatternRewriter &rewriter) const final { -// auto module = op->getParentOfType(); + // 1) make sure the postgres_getresult function is declared + auto executefn = dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQexec"))); -// SymbolTableCollection symbolTable; -// symbolTable.getSymbolTable(module); + // 2) convert the args to valid args to postgres_getresult abi + Value conn = op.getConn(); + conn = rewriter.create(op.getLoc(), + rewriter.getI8Type(), conn); + conn = rewriter.create( + op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), conn); + conn = rewriter.create(op.getLoc(), + executefn.getFunctionType().getInput(0), conn); -// // 1) make sure the postgres_getresult function is declared -// auto executefn = dyn_cast_or_null( -// symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQexec"))); + Value command = op.getCommand(); + command = rewriter.create(op.getLoc(), + rewriter.getI8Type(), command); + command = rewriter.create( + op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), command); + StringRef strname = command.getDefiningOp().getGlobalName(); + Attribute strattr = dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, rewriter.getStringAttr(strname))).getValueAttr(); + auto str = strattr.cast().getValue(); + llvm::errs() << str << "\n"; + command = rewriter.create(op.getLoc(), + executefn.getFunctionType().getInput(1), command); -// // 2) convert the args to valid args to postgres_getresult abi -// Value conn = op.getConn(); -// conn = rewriter.create(op.getLoc(), -// rewriter.getI64Type(), conn); -// conn = rewriter.create( -// op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), conn); -// Value command = op.getCommand(); - -// // 3) call and replace -// Value args[] = {conn, command}; + // 3) call and replace + Value args[] = {conn, command}; -// Value res = -// 
rewriter.create(op.getLoc(), executefn, args) -// ->getResult(0); + Value res = + rewriter.create(op.getLoc(), executefn, args) + ->getResult(0); -// rewriter.replaceOpWithNewOp( -// op, rewriter.getIndexType(), res); + rewriter.replaceOp(op, res); -// return success(); -// } -// }; + return success(); + } +}; void SQLLower::runOnOperation() { auto module = getOperation(); @@ -251,8 +211,8 @@ void SQLLower::runOnOperation() { if (!dyn_cast_or_null(symbolTable.lookupSymbolIn( module, builder.getStringAttr("PQntuples")))) { - mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI8Type())}; - mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI64Type())}; + mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type())}; + mlir::Type rettypes[] = {builder.getI64Type()}; auto fn = builder.create(module.getLoc(), "PQntuples", @@ -263,10 +223,10 @@ void SQLLower::runOnOperation() { if (!dyn_cast_or_null(symbolTable.lookupSymbolIn( module, builder.getStringAttr("PQgetvalue")))) { mlir::Type argtypes[] = { - LLVM::LLVMPointerType::get(builder.getI8Type()), - LLVM::LLVMPointerType::get(builder.getI64Type()), - LLVM::LLVMPointerType::get(builder.getI64Type())}; - mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI8Type())}; + MemRefType::get({-1}, builder.getI8Type()), + builder.getI64Type(), + builder.getI64Type()}; + mlir::Type rettypes[] = {MemRefType::get({-1}, builder.getI8Type())}; auto fn = builder.create(module.getLoc(), "PQgetvalue", @@ -274,21 +234,21 @@ void SQLLower::runOnOperation() { SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); } -// if (!dyn_cast_or_null( -// symbolTable.lookupSymbolIn(module, builder.getStringAttr("PQexec")))) { -// mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI64Type()), -// LLVM::LLVMPointerType::get(builder.getI8Type())}; -// mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI64Type())}; + if (!dyn_cast_or_null( + 
symbolTable.lookupSymbolIn(module, builder.getStringAttr("PQexec")))) { + mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type()), + MemRefType::get({-1}, builder.getI8Type())}; + mlir::Type rettypes[] = {MemRefType::get({-1}, builder.getI8Type())}; -// auto fn = builder.create( -// module.getLoc(), "PQexec", builder.getFunctionType(argtypes, rettypes)); -// SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); -// } + auto fn = builder.create( + module.getLoc(), "PQexec", builder.getFunctionType(argtypes, rettypes)); + SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); + } if (!dyn_cast_or_null( symbolTable.lookupSymbolIn(module, builder.getStringAttr("atoi")))) { - // mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type())}; - mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI64Type())}; + mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type())}; + // mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI64Type())}; // todo use data layout mlir::Type rettypes[] = {builder.getI64Type()}; @@ -301,7 +261,7 @@ void SQLLower::runOnOperation() { RewritePatternSet patterns(&getContext()); patterns.insert(&getContext()); patterns.insert(&getContext()); - // patterns.insert(&getContext()); + patterns.insert(&getContext()); GreedyRewriteConfig config; (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), diff --git a/lib/sql/Passes/SQLRaising.cpp b/lib/sql/Passes/SQLRaising.cpp index f843cc120a34..8133002a45e0 100644 --- a/lib/sql/Passes/SQLRaising.cpp +++ b/lib/sql/Passes/SQLRaising.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "PassDetails.h" +#include "polygeist/Ops.h" #include "mlir/Analysis/CallGraph.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Async/IR/Async.h" @@ -40,45 +41,35 @@ struct SQLRaising : public SQLRaisingBase { } // end anonymous namespace -struct 
PQcmdTuplesRaising : public OpRewritePattern { +struct PQntuplesRaising : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(func::CallOp call, PatternRewriter &rewriter) const final { - if (call.getCallee() != "PQcmdTuples") { + if (call.getCallee() != "PQntuples") { return failure(); } - SymbolTableCollection symbolTable; - symbolTable.getSymbolTable(call); auto module = call->getParentOfType(); + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + // 2) convert the args to valid args to postgres_getresult abi Value arg = call.getArgOperands()[0]; - arg = rewriter.create( - call.getLoc(), rewriter.getIntegerType(64), arg); + arg = rewriter.create( + call.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), arg); + arg = rewriter.create(call.getLoc(), rewriter.getIntegerType(64), arg); + arg = rewriter.create(call.getLoc(), rewriter.getIndexType(), arg); - arg = rewriter.create(call.getLoc(), - rewriter.getIndexType(), arg); Value res = rewriter.create(call.getLoc(), rewriter.getIndexType(), arg); - res = rewriter.create(call.getLoc(), rewriter.getI64Type(), res); - - auto itoafn = dyn_cast_or_null( - symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("itoa"))); - - Value args2[] = {res}; - - rewriter.replaceOpWithNewOp(call, itoafn, args2); - - // 4) done + rewriter.replaceOp(call, res); return success(); } }; - - struct PQgetvalueRaising : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -87,79 +78,79 @@ struct PQgetvalueRaising : public OpRewritePattern { if (call.getCallee() != "PQgetvalue") { return failure(); } - SymbolTableCollection symbolTable; - symbolTable.getSymbolTable(call); auto module = call->getParentOfType(); + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); // 2) convert the args to valid args to postgres_getresult abi Value handle = call.getArgOperands()[0]; - handle = rewriter.create( - call.getLoc(), 
rewriter.getIntegerType(64), handle); - - handle = rewriter.create(call.getLoc(), - rewriter.getIndexType(), handle); + handle = rewriter.create( + call.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), handle); + handle = rewriter.create(call.getLoc(), rewriter.getIntegerType(64), handle); + handle = rewriter.create(call.getLoc(), rewriter.getIndexType(), handle); Value row = call.getArgOperands()[1]; + row = rewriter.create(call.getLoc(), rewriter.getIndexType(), row); Value column = call.getArgOperands()[2]; + column = rewriter.create(call.getLoc(), rewriter.getIndexType(), column); Value res = rewriter.create(call.getLoc(), rewriter.getIndexType(), handle, row, column); - // or Value res = rewriter.create(call.getLoc(), rewriter.getIndexType(), {handle, row, column}); res = rewriter.create(call.getLoc(), rewriter.getI64Type(), res); - auto itoafn = dyn_cast_or_null( - symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("itoa"))); - Value args2[] = {res}; - rewriter.replaceOpWithNewOp(call, itoafn, args2); - // 4) done + auto itoafn = dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("itoa"))); + + rewriter.replaceOpWithNewOp(call, itoafn, args2); + return success(); } }; -struct PQexecRaising : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +// struct PQexecRaising : public OpRewritePattern { +// using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(func::CallOp call, - PatternRewriter &rewriter) const final { - if (call.getCallee() != "PQexec") { - return failure(); - } - SymbolTableCollection symbolTable; - symbolTable.getSymbolTable(call); - auto module = call->getParentOfType(); +// LogicalResult matchAndRewrite(func::CallOp call, +// PatternRewriter &rewriter) const final { +// if (call.getCallee() != "PQexec") { +// return failure(); +// } +// SymbolTableCollection symbolTable; +// symbolTable.getSymbolTable(call); +// auto module = call->getParentOfType(); - // 2) 
convert the args to valid args to postgres_getresult abi - Value conn = call.getArgOperands()[0]; - conn = rewriter.create( - call.getLoc(), rewriter.getIntegerType(64), conn); +// // 2) convert the args to valid args to postgres_getresult abi +// Value conn = call.getArgOperands()[0]; +// conn = rewriter.create( +// call.getLoc(), rewriter.getIntegerType(64), conn); - conn = rewriter.create(call.getLoc(), - rewriter.getIndexType(), conn); +// conn = rewriter.create(call.getLoc(), +// rewriter.getIndexType(), conn); - Value command = call.getArgOperands()[1]; - command = rewriter.create( - call.getLoc(), rewriter.getIntegerType(64), command); +// Value command = call.getArgOperands()[1]; +// command = rewriter.create( +// call.getLoc(), rewriter.getIntegerType(64), command); - command = rewriter.create(call.getLoc(), - rewriter.getIndexType(), command); +// command = rewriter.create(call.getLoc(), +// rewriter.getIndexType(), command); - Value res = rewriter.create(call.getLoc(), rewriter.getIndexType(), conn, command); +// Value res = rewriter.create(call.getLoc(), rewriter.getIndexType(), conn, command); - res = rewriter.create(call.getLoc(), - rewriter.getI64Type(), res); +// res = rewriter.create(call.getLoc(), +// rewriter.getI64Type(), res); - rewriter.replaceOp(call, res); - /// rewriter.replaceOpWithNewOp(call, itoafn, res); +// rewriter.replaceOp(call, res); +// /// rewriter.replaceOpWithNewOp(call, itoafn, res); - // 4) done - return success(); - } -}; +// // 4) done +// return success(); +// } +// }; void SQLRaising::runOnOperation() { auto module = getOperation(); @@ -170,24 +161,24 @@ void SQLRaising::runOnOperation() { if (!dyn_cast_or_null( symbolTable.lookupSymbolIn(module, builder.getStringAttr("itoa")))) { - mlir::Type rettypes[] = {LLVM::LLVMPointerType::get(builder.getI8Type())}; - - // todo use data layout mlir::Type argtypes[] = {builder.getI64Type()}; + mlir::Type rettypes[] = {MemRefType::get({-1}, builder.getI8Type())}; auto fn = 
builder.create( module.getLoc(), "itoa", builder.getFunctionType(argtypes, rettypes)); SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); } + + RewritePatternSet patterns(&getContext()); - patterns.insert(&getContext()); + patterns.insert(&getContext()); patterns.insert(&getContext()); // patterns.insert(&getContext()); GreedyRewriteConfig config; (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config); + config); } namespace mlir { From b0f6226fb39bd1b7f13531e67ccf43d9de1b3cda Mon Sep 17 00:00:00 2001 From: Carl Guo Date: Mon, 7 Aug 2023 11:25:39 -0400 Subject: [PATCH 11/15] parser works --- include/sql/CMakeLists.txt | 1 - include/sql/Parser.h | 30 +++++ include/sql/SQLDialect.td | 15 +-- include/sql/SQLOps.h | 2 +- include/sql/SQLOps.td | 34 ++++- include/sql/SQLTypes.h | 19 +-- include/sql/SQLTypes.td | 13 +- lib/sql/CMakeLists.txt | 4 + lib/sql/Dialect.cpp | 3 + lib/sql/Ops.cpp | 122 +++++++++++------ lib/sql/Parser.cpp | 244 ++++++++++++++++++++++++++++++++++ lib/sql/Passes/SQLRaising.cpp | 106 ++++++++++----- lib/sql/Types.cpp | 37 ++++++ 13 files changed, 522 insertions(+), 108 deletions(-) create mode 100644 include/sql/Parser.h create mode 100644 lib/sql/Parser.cpp create mode 100644 lib/sql/Types.cpp diff --git a/include/sql/CMakeLists.txt b/include/sql/CMakeLists.txt index 8764d8eddbb4..1ee2b045b2cc 100644 --- a/include/sql/CMakeLists.txt +++ b/include/sql/CMakeLists.txt @@ -1,5 +1,4 @@ add_mlir_dialect(SQLOps sql) # add_mlir_doc(SQLDialect -gen-dialect-doc SQLDialect SQL/) # add_mlir_doc(SQLOps -gen-op-doc SQLOps SQL/) - add_subdirectory(Passes) \ No newline at end of file diff --git a/include/sql/Parser.h b/include/sql/Parser.h new file mode 100644 index 000000000000..99ba7579173e --- /dev/null +++ b/include/sql/Parser.h @@ -0,0 +1,30 @@ +//===- Parser.h - SQL dialect -----------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + + + + +#ifndef SQLPARSER_H +#define SQLPARSER_H + +#include "mlir/IR/Dialect.h" + +#include "mlir/IR/Value.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Attributes.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/BuiltinTypes.h" + +#include "sql/SQLDialect.h" +#include "sql/SQLOps.h" +#include "sql/SQLTypes.h" + +mlir::Value parseSQL(mlir::Location loc, mlir::OpBuilder& builder, std::string str); + +#endif // SQLPARSER_H \ No newline at end of file diff --git a/include/sql/SQLDialect.td b/include/sql/SQLDialect.td index 0bd8af6bcb0e..8026cd4268cf 100644 --- a/include/sql/SQLDialect.td +++ b/include/sql/SQLDialect.td @@ -10,11 +10,10 @@ #define SQL_DIALECT -include "mlir/IR/OpBase.td" + include "mlir/IR/FunctionInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" - include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" @@ -26,15 +25,13 @@ def SQL_Dialect : Dialect { }]; let name = "sql"; let cppNamespace = "::mlir::sql"; -} - -//===----------------------------------------------------------------------===// -// SQL Operations -//===----------------------------------------------------------------------===// + let useDefaultTypePrinterParser = 1; + let extraClassDeclaration = [{ + void registerTypes(); + }]; +} -class SQL_Op traits = []> - : Op; #endif // SQL_DIALECT diff --git a/include/sql/SQLOps.h b/include/sql/SQLOps.h index 06d0319150bb..1b5c6f4783ec 100644 --- a/include/sql/SQLOps.h +++ b/include/sql/SQLOps.h @@ -16,7 +16,7 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/Support/CommandLine.h" - +#include "sql/SQLTypes.h" #define GET_OP_CLASSES #include 
"sql/SQLOps.h.inc" diff --git a/include/sql/SQLOps.td b/include/sql/SQLOps.td index db96c52a293b..5e87f38f7d0c 100644 --- a/include/sql/SQLOps.td +++ b/include/sql/SQLOps.td @@ -9,11 +9,17 @@ #ifndef SQL_OPS #define SQL_OPS -include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" include "SQLDialect.td" +include "SQLTypes.td" + + +class SQL_Op traits = []> + : Op; + def IntOp : SQL_Op<"int", [Pure]> { - let summary = "select"; + let summary = "int op"; let arguments = (ins StrAttr:$expr); let results = (outs SQLExprType:$result); @@ -23,7 +29,7 @@ def IntOp : SQL_Op<"int", [Pure]> { } def ColumnOp : SQL_Op<"column", [Pure]> { - let summary = "select"; + let summary = "column op"; let arguments = (ins StrAttr:$expr); let results = (outs SQLExprType:$result); @@ -33,7 +39,7 @@ def ColumnOp : SQL_Op<"column", [Pure]> { } def TableOp : SQL_Op<"table", [Pure]> { - let summary = "select"; + let summary = "table"; let arguments = (ins StrAttr:$expr); let results = (outs SQLExprType:$result); @@ -42,10 +48,24 @@ def TableOp : SQL_Op<"table", [Pure]> { let hasCanonicalizer = 0; } +def EmptyTableOp : SQL_Op<"empty_table", [Pure]> { + let summary = "empty_table"; + // i need to specify the size of a Variadic? + let arguments = (ins StrAttr:$expr); + let results = (outs SQLExprType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +// def VectorNotEmpty : AttrConstraint 0">, +// "VectorNotEmpty">; + + def SelectOp : SQL_Op<"select", [Pure]> { let summary = "select"; - - let arguments = (ins Variadic:$columns, Optional:$table); + // i need to specify the size of a Variadic? 
+ let arguments = (ins Variadic:$columns, SQLExprType:$table); let results = (outs SQLExprType:$result); let hasFolder = 0; @@ -69,7 +89,7 @@ def ExecuteOp : SQL_Op<"execute", []> { let results = (outs Index:$result); let hasFolder = 0; - let hasCanonicalizer = 1; + let hasCanonicalizer = 0; } diff --git a/include/sql/SQLTypes.h b/include/sql/SQLTypes.h index 501b697df78f..28f9d91e393a 100644 --- a/include/sql/SQLTypes.h +++ b/include/sql/SQLTypes.h @@ -1,4 +1,4 @@ -//===- SQLOps.h - SQL dialect ops --------------------*- C++ -*-===// +//===- SQLTypes.h - SQL dialect types --------------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,18 +6,13 @@ // //===----------------------------------------------------------------------===// -#ifndef SQLTYPES_H -#define SQLTYPES_H +#ifndef SQL_SQLTYPES_H +#define SQL_SQLTYPES_H #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Interfaces/ViewLikeInterface.h" -#include "llvm/Support/CommandLine.h" -#define GET_TYPE_CLASSES -#include "sql/SQLTypes.h.inc" +#define GET_TYPEDEF_CLASSES +#include "sql/SQLOpsTypes.h.inc" -#endif \ No newline at end of file + +#endif // SQL_SQLTYPES_H \ No newline at end of file diff --git a/include/sql/SQLTypes.td b/include/sql/SQLTypes.td index a98fe7d3def8..4774a2a07810 100644 --- a/include/sql/SQLTypes.td +++ b/include/sql/SQLTypes.td @@ -12,9 +12,20 @@ include "mlir/IR/AttrTypeBase.td" include "SQLDialect.td" +class SQL_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + + + def SQLExprType : SQL_Type<"Expr", "expr"> { let summary = "SQL expression type"; -} + let description = "Custom attr or value type in sql dialect"; + // placeholder params + // let parameters = (ins StringRefParameter<"the custom 
value">:$value); + // let assemblyFormat = "`<` $value `>`"; +} #endif // SQL_TYPES \ No newline at end of file diff --git a/lib/sql/CMakeLists.txt b/lib/sql/CMakeLists.txt index cd4bcfde1856..fa5bae87f548 100644 --- a/lib/sql/CMakeLists.txt +++ b/lib/sql/CMakeLists.txt @@ -1,12 +1,16 @@ add_mlir_dialect_library(MLIRSQL +Types.cpp Dialect.cpp Ops.cpp +Parser.cpp + ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/sql DEPENDS MLIRSQLOpsIncGen +# MLIRSQLTypesIncGen LINK_LIBS PUBLIC MLIRIR diff --git a/lib/sql/Dialect.cpp b/lib/sql/Dialect.cpp index eb42b990bf4a..104e510bf59d 100644 --- a/lib/sql/Dialect.cpp +++ b/lib/sql/Dialect.cpp @@ -9,6 +9,7 @@ #include "mlir/IR/DialectImplementation.h" #include "sql/SQLDialect.h" #include "sql/SQLOps.h" +#include "sql/SQLTypes.h" using namespace mlir; using namespace mlir::sql; @@ -22,6 +23,8 @@ void SQLDialect::initialize() { #define GET_OP_LIST #include "sql/SQLOps.cpp.inc" >(); +registerTypes(); + } #include "sql/SQLOpsDialect.cpp.inc" diff --git a/lib/sql/Ops.cpp b/lib/sql/Ops.cpp index 45bc8264f22e..96116ad193cf 100644 --- a/lib/sql/Ops.cpp +++ b/lib/sql/Ops.cpp @@ -5,6 +5,10 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +#include +#include +#include +#include #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -16,6 +20,8 @@ #include "sql/SQLDialect.h" #include "sql/SQLOps.h" +#include "sql/SQLTypes.h" +#include "sql/Parser.h" #include "polygeist/Ops.h" #define GET_OP_CLASSES @@ -35,6 +41,14 @@ #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" + +#include "mlir/IR/Value.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Attributes.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/BuiltinTypes.h" + #define DEBUG_TYPE "sql" using namespace mlir; @@ -122,54 +136,80 @@ void NumResultsOp::getCanonicalizationPatterns(RewritePatternSet 
&results, -class ExecuteOpTypeFix final : public OpRewritePattern { +// class ExecuteOpTypeFix final : public OpRewritePattern { +// public: +// using OpRewritePattern::OpRewritePattern; + +// LogicalResult matchAndRewrite(ExecuteOp op, +// PatternRewriter &rewriter) const override { +// bool changed = false; + +// Value conn = op->getOperand(0); +// Value command = op->getOperand(1); + +// if (conn.getType().isa() && command.getType().isa() && op->getResultTypes()[0].isa()) +// return failure(); + +// if (!conn.getType().isa()) { +// conn = rewriter.create(op.getLoc(), +// rewriter.getIndexType(), conn); +// changed = true; +// } +// if (command.getType().isa()) { +// command = rewriter.create(op.getLoc(), +// LLVM::LLVMPointerType::get(rewriter.getI8Type()), command); +// changed = true; +// } + + +// if (command.getType().isa()) { +// command = rewriter.create(op.getLoc(), +// rewriter.getI64Type(), command); +// changed = true; +// } +// if (!command.getType().isa()) { +// command = rewriter.create(op.getLoc(), +// rewriter.getIndexType(), command); +// changed = true; +// } + +// if (!changed) return failure(); +// mlir::Value res = rewriter.create(op.getLoc(), rewriter.getIndexType(), conn, command); +// rewriter.replaceOp(op, res); +// // if (op->getResultTypes()[0].isa()) { +// // rewriter.replaceOp(op, res); +// // } else { +// // rewriter.replaceOpWithNewOp(op, op->getResultTypes()[0], res); +// // } +// return success(changed); +// } +// }; + +// void ExecuteOp::getCanonicalizationPatterns(RewritePatternSet &results, +// MLIRContext *context) { +// results.insert(context); +// } + + +template +class UnparsedOpInnerCast final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ExecuteOp op, + LogicalResult matchAndRewrite(UnparsedOp op, PatternRewriter &rewriter) const override { - bool changed = false; - Value conn = op->getOperand(0); - Value command 
= op->getOperand(1); + Value input = op->getOperand(0); + + auto cst = input.getDefiningOp(); + if (!cst) return failure(); - if (conn.getType().isa() && command.getType().isa() && op->getResultTypes()[0].isa()) - return failure(); - - if (!conn.getType().isa()) { - conn = rewriter.create(op.getLoc(), - rewriter.getIndexType(), conn); - changed = true; - } - if (command.getType().isa()) { - command = rewriter.create(op.getLoc(), - LLVM::LLVMPointerType::get(rewriter.getI8Type()), command); - changed = true; - } - if (command.getType().isa()) { - command = rewriter.create(op.getLoc(), - rewriter.getI64Type(), command); - changed = true; - } - if (!command.getType().isa()) { - command = rewriter.create(op.getLoc(), - rewriter.getIndexType(), command); - changed = true; - } - - if (!changed) return failure(); - mlir::Value res = rewriter.create(op.getLoc(), rewriter.getIndexType(), conn, command); - rewriter.replaceOp(op, res); - // if (op->getResultTypes()[0].isa()) { - // rewriter.replaceOp(op, res); - // } else { - // rewriter.replaceOpWithNewOp(op, op->getResultTypes()[0], res); - // } - return success(changed); + rewriter.replaceOpWithNewOp(op, op.getType(), cst.getOperand()); + return success(); } }; -void ExecuteOp::getCanonicalizationPatterns(RewritePatternSet &results, +void UnparsedOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.insert(context); + results.insert >(context); } \ No newline at end of file diff --git a/lib/sql/Parser.cpp b/lib/sql/Parser.cpp new file mode 100644 index 000000000000..62bc4973df84 --- /dev/null +++ b/lib/sql/Parser.cpp @@ -0,0 +1,244 @@ +#include +#include +#include +#include + +#include "mlir/IR/Value.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Attributes.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/BuiltinTypes.h" + + +#include "sql/SQLDialect.h" +#include "sql/SQLOps.h" +#include "sql/SQLTypes.h" + + +using namespace mlir; +using 
namespace sql; +enum class ParseType { + Nothing = 0, + Value = 1, + Attribute = 2, +}; + +struct ParseValue { +private: + ParseType ty; + Value value; + Attribute attr; +public: + ParseValue() : ty(ParseType::Nothing), value(nullptr), attr(nullptr) {} + ParseValue(Value value) : ty(ParseType::Value), value(value), attr(nullptr) { + assert(value); + } + ParseValue(Attribute attr) : ty(ParseType::Attribute), value(nullptr), attr(attr) {} + + ParseType getType() const { + return ty; + } + Value getValue() const { + int type_value = static_cast(ty); + // llvm::errs() << "ty: " << type_value << "\n"; + // llvm::errs() << "value: " << value << "\n"; + // llvm::errs() << "attr: " << attr << "\n"; + assert(ty == ParseType::Value); + assert(value); + return value; + } + Attribute getAttr() const { + assert(ty == ParseType::Attribute); + return attr; + } +}; + +enum class ParseMode { + None, + Column, + Table +}; + + +template +std::ostream& operator<<(typename std::enable_if::value, std::ostream>::type& stream, const T& e) +{ + return stream << static_cast::type>(e); +} + +class SQLParser { + +public: + Location loc; + OpBuilder &builder; + std::string sql; + unsigned int i; + + static std::vector reservedWords; + + + SQLParser(Location loc, OpBuilder &builder, std::string sql, int i) : loc(loc), builder(builder), + sql(sql), i(i) { + // llvm::errs() << "last three " << sql.substr(sql.size()-6, sql.size()) << "\n"; + // if (!sql.substr(sql.size()-2, sql.size()).compare("\\00")){ + // llvm::errs() << "triggers trim" << "\n"; + // sql = sql.substr(0, sql.size()-1); + // } + } + + std::string peek() { + auto [peeked, _] = peekWithLength(); + return peeked; + } + + std::string pop() { + auto [peeked, len] = peekWithLength(); + i += len; + popWhitespace(); + return peeked; + } + + void popWhitespace() { + // it doesn't recognize + while (i < sql.size() && (sql[i] == ' ' || sql[i] == '\n' || sql[i] == '\t')) { + i++; + } + } + + std::pair peekWithLength() { + if (i >= 
sql.size()) { + return {"", 0}; + } + for (std::string rWord : reservedWords) { + auto token = sql.substr(i, std::min(sql.size(), i + rWord.size())); + std::transform(token.begin(), token.end(), token.begin(), ::toupper); + if (token == rWord) { + return {token, static_cast(token.size())}; + } + } + if (sql[i] == '\'') { // Quoted string + return peekQuotedStringWithLength(); + } + return peekIdentifierWithLength(); + } + + std::pair peekQuotedStringWithLength() { + if (sql.size() < i || sql[i] != '\'') { + return {"", 0}; + } + for (unsigned int j = i + 1; j < sql.size(); j++) { + if (sql[j] == '\'' && sql[j-1] != '\\') { + return {sql.substr(i + 1, j - (i + 1)), j - i + 2}; // +2 for the two quotes + } + } + return {"", 0}; + } + + std::pair peekIdentifierWithLength() { + std::regex e("[a-zA-Z0-9_*]"); + for (unsigned int j = i; j < sql.size(); j++) { + if (!std::regex_match(std::string(1, sql[j]), e)) { + return {sql.substr(i, j - i), j - i}; + } + } + return {sql.substr(i), static_cast(sql.size())-i}; + } + + + bool is_number(std::string* s){ + std::string::const_iterator it = s->begin(); + while (it != s->end() && std::isdigit(*it)) ++it; + return !s->empty() && it == s->end(); + } + + + + // Parse the next command, if any + ParseValue parseNext(ParseMode mode) { + if (i >= sql.size()) { + llvm::errs() << "here i:" << i << "\n"; + llvm::errs() << "here size:" << sql.size() << "\n"; + return ParseValue(); + } + auto peekStr = peek(); + llvm::errs() << "peekStr: " << peekStr << "\n"; + assert(peekStr.size() > 0); + + if (peekStr == "SELECT") { + assert(mode == ParseMode::None || mode == ParseMode::Table); + pop(); + peekStr = peek(); + if (peekStr == "DISTINCT") { + pop(); + // do something different here + } + llvm::SmallVector columns; + bool hasColumns = true; + bool hasWhere = false; + Value table = nullptr; + while (true) { + peekStr = peek(); + // if (hasColumns) { + if (peekStr == "FROM") { + pop(); + table = parseNext(ParseMode::Table).getValue(); + 
hasColumns = false; + break; + } + ParseValue col = parseNext(ParseMode::Column); + if (col.getType() == ParseType::Nothing) { + hasColumns = false; + break; + } else { + columns.push_back(col.getValue()); + } + if (peekStr == ",") pop(); + // } else if (peekStr == "WHERE") { + // pop(); + // hasWhere = true; + // } else if (hasWhere){ + // // do something here + // break; + // } else { + // break; + // // assert(0 && " additional clauses like limit/etc not yet handled"); + // } + } + if (table) + return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), columns, table).getResult()); + else + return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), columns).getResult()); + } else if (is_number(&peekStr)){ + pop(); + return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), builder.getStringAttr(peekStr)).getResult()); + } else if (mode == ParseMode::Column) { + // do we need this?? + // if (peekStr == "*") { + // pop(); + + // return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), ); + // } + pop(); + return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), builder.getStringAttr(peekStr)).getResult()); + } else if (mode == ParseMode::Table) { + pop(); + return ParseValue(builder.create(loc,ExprType::get(builder.getContext()) , builder.getStringAttr(peekStr)).getResult()); + } + llvm::errs() << " Unknown token to parse: " << peekStr << "\n"; + llvm_unreachable("Unknown token to parse"); + } + +}; + +std::vector SQLParser::reservedWords = { + "(", ")", ">=", "<=", "!=", ",", "=", ">", "<", "SELECT", "INSERT INTO", "VALUES", "UPDATE", "DELETE FROM", "WHERE", "FROM", "SET", "AS" +}; + + +mlir::Value parseSQL(mlir::Location loc, mlir::OpBuilder& builder, std::string str) { + SQLParser parser(loc, builder, std::string(str), 0); + auto resOp = parser.parseNext(ParseMode::None); + + return resOp.getValue(); +} \ No newline at end of file diff --git a/lib/sql/Passes/SQLRaising.cpp 
b/lib/sql/Passes/SQLRaising.cpp index 8133002a45e0..0c0352c06328 100644 --- a/lib/sql/Passes/SQLRaising.cpp +++ b/lib/sql/Passes/SQLRaising.cpp @@ -23,6 +23,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "sql/SQLOps.h" +#include "sql/Parser.h" #include "sql/Passes/Passes.h" #include #include @@ -112,51 +113,83 @@ struct PQgetvalueRaising : public OpRewritePattern { }; -// struct PQexecRaising : public OpRewritePattern { -// using OpRewritePattern::OpRewritePattern; +struct PQexecRaising : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; -// LogicalResult matchAndRewrite(func::CallOp call, -// PatternRewriter &rewriter) const final { -// if (call.getCallee() != "PQexec") { -// return failure(); -// } -// SymbolTableCollection symbolTable; -// symbolTable.getSymbolTable(call); -// auto module = call->getParentOfType(); + LogicalResult matchAndRewrite(func::CallOp call, + PatternRewriter &rewriter) const final { + if (call.getCallee() != "PQexec") { + return failure(); + } -// // 2) convert the args to valid args to postgres_getresult abi -// Value conn = call.getArgOperands()[0]; -// conn = rewriter.create( -// call.getLoc(), rewriter.getIntegerType(64), conn); + // 2) convert the args to valid args to postgres_getresult abi + Value conn = call.getArgOperands()[0]; + if (!conn.getType().isa()) { + conn = rewriter.create( + call.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), conn); + } + conn = rewriter.create( + call.getLoc(), rewriter.getIntegerType(64), conn); -// conn = rewriter.create(call.getLoc(), -// rewriter.getIndexType(), conn); + conn = rewriter.create(call.getLoc(), + rewriter.getIndexType(), conn); -// Value command = call.getArgOperands()[1]; -// command = rewriter.create( -// call.getLoc(), rewriter.getIntegerType(64), command); + Value command = call.getArgOperands()[1]; -// command = rewriter.create(call.getLoc(), -// rewriter.getIndexType(), command); + command = 
rewriter.create(call.getLoc(), + ExprType::get(rewriter.getContext()), command); -// Value res = rewriter.create(call.getLoc(), rewriter.getIndexType(), conn, command); + Value res = rewriter.create(call.getLoc(), rewriter.getIndexType(), conn, command); -// res = rewriter.create(call.getLoc(), -// rewriter.getI64Type(), res); + res = rewriter.create(call.getLoc(), + rewriter.getI64Type(), res); + res = rewriter.create(call.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), res); + res = rewriter.create(call.getLoc(), call.getResultTypes()[0], res); + + + assert(call.getResultTypes()[0] == res.getType()); + rewriter.replaceOp(call, res); -// rewriter.replaceOp(call, res); -// /// rewriter.replaceOpWithNewOp(call, itoafn, res); + return success(); + } +}; -// // 4) done -// return success(); -// } -// }; + +class UnparsedOpParseFix final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(UnparsedOp op, + PatternRewriter &rewriter) const override { + + Value input = op->getOperand(0); + + auto stringOp = input.getDefiningOp(); + if (!stringOp) return failure(); + + StringRef strname = stringOp.getGlobalName(); + + auto module = op->getParentOfType(); + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + + Attribute strattr = dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, rewriter.getStringAttr(strname))).getValueAttr(); + auto str = strattr.cast().getValue(); + + auto resOp = parseSQL(op.getLoc(), rewriter, std::string(str.data())); + assert(resOp); + rewriter.replaceOp(op,ValueRange(resOp)); + return success(); + } +}; void SQLRaising::runOnOperation() { auto module = getOperation(); SymbolTableCollection symbolTable; symbolTable.getSymbolTable(module); - OpBuilder builder(module.getContext()); + auto &context = getContext(); + OpBuilder builder(&context); builder.setInsertionPointToStart(module.getBody()); if (!dyn_cast_or_null( @@ -169,12 +202,13 @@ void 
SQLRaising::runOnOperation() { SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); } + RewritePatternSet patterns(&context); + patterns.insert(&getContext()); - - RewritePatternSet patterns(&getContext()); - patterns.insert(&getContext()); - patterns.insert(&getContext()); - // patterns.insert(&getContext()); + for (auto *dialect : context.getLoadedDialects()) + dialect->getCanonicalizationPatterns(patterns); + for (RegisteredOperationName op : context.getRegisteredOperations()) + op.getCanonicalizationPatterns(patterns, &context); GreedyRewriteConfig config; (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), diff --git a/lib/sql/Types.cpp b/lib/sql/Types.cpp new file mode 100644 index 000000000000..56d682068602 --- /dev/null +++ b/lib/sql/Types.cpp @@ -0,0 +1,37 @@ +//===- SQLTypes.cpp - SQL dialect types ---------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "sql/SQLDialect.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/DialectImplementation.h" +#include "sql/SQLTypes.h" + + +#define DEBUG_TYPE "sql" + +using namespace mlir::sql; + +#define GET_TYPEDEF_CLASSES +#include "sql/SQLOpsTypes.cpp.inc" + + +void SQLDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "sql/SQLOpsTypes.cpp.inc" + >(); +} \ No newline at end of file From 2b568ce0cb52c9b0858796b0236f56b10d091d0a Mon Sep 17 00:00:00 2001 From: Carl Guo Date: Fri, 29 Sep 2023 10:20:20 -0400 Subject: [PATCH 12/15] progress --- include/sql/SQLOps.td | 84 ++++++++++++++++++++--- include/sql/SQLTypes.td | 10 +++ lib/sql/Ops.cpp | 23 +++++++ lib/sql/Parser.cpp | 86 +++++++++++------------ lib/sql/Passes/SQLLower.cpp | 131 +++++++++++++++++++++++++++++++++--- 5 files changed, 273 insertions(+), 61 deletions(-) diff --git a/include/sql/SQLOps.td b/include/sql/SQLOps.td index 5e87f38f7d0c..94c4dc0a67f1 100644 --- a/include/sql/SQLOps.td +++ b/include/sql/SQLOps.td @@ -38,8 +38,8 @@ def ColumnOp : SQL_Op<"column", [Pure]> { let hasCanonicalizer = 0; } -def TableOp : SQL_Op<"table", [Pure]> { - let summary = "table"; +def AllColumnsOp : SQL_Op<"all_columns", [Pure]> { + let summary = "all columns op"; let arguments = (ins StrAttr:$expr); let results = (outs SQLExprType:$result); @@ -48,9 +48,49 @@ def TableOp : SQL_Op<"table", [Pure]> { let hasCanonicalizer = 0; } -def EmptyTableOp : SQL_Op<"empty_table", [Pure]> { - let summary = "empty_table"; - // i need to specify the size of a Variadic? 
+def WhereOp: SQL_Op<"where", [Pure]> { + let summary = "where op"; + + let arguments = (ins SQLBoolType:$expr); + let results = (outs SQLExprType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +def CalcBoolOp: SQL_Op<"calc_bool", [Pure]> { + let summary = "calc_bool op"; + + let arguments = (ins StrAttr:$expr); + let results = (outs SQLBoolType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +def AndOp: SQL_Op<"and", [Pure]> { + let summary = "and op"; + + let arguments = (ins Variadic:$expr); + let results = (outs SQLBoolType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +def OrOp: SQL_Op<"or", [Pure]> { + let summary = "or op"; + + let arguments = (ins Variadic:$expr); + let results = (outs SQLBoolType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +def TableOp : SQL_Op<"table", [Pure]> { + let summary = "table"; + let arguments = (ins StrAttr:$expr); let results = (outs SQLExprType:$result); @@ -58,10 +98,38 @@ def EmptyTableOp : SQL_Op<"empty_table", [Pure]> { let hasCanonicalizer = 0; } -// def VectorNotEmpty : AttrConstraint 0">, -// "VectorNotEmpty">; - +def SQLConstantStringOp : SQL_Op<"str_constant", [Pure]> { + let summary = "str_constant"; + + let arguments = (ins StrAttr:$input); + let results = (outs AnyType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +def SQLToStringOp : SQL_Op<"to_string", [Pure]> { + let summary = "to_string"; + + let arguments = (ins SQLExprType:$expr); + let results = (outs AnyType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + + +def SQLStringConcatOp : SQL_Op<"string_concat", [Pure]> { + let summary = "string_concat"; + + let arguments = (ins Variadic:$expr); + let results = (outs AnyType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 1; +} + def SelectOp : SQL_Op<"select", [Pure]> { let summary = "select"; // i need to specify the size of a Variadic? 
diff --git a/include/sql/SQLTypes.td b/include/sql/SQLTypes.td index 4774a2a07810..8d0800a62c5e 100644 --- a/include/sql/SQLTypes.td +++ b/include/sql/SQLTypes.td @@ -28,4 +28,14 @@ def SQLExprType : SQL_Type<"Expr", "expr"> { // let assemblyFormat = "`<` $value `>`"; } + +def SQLBoolType : SQL_Type<"Bool", "bool"> { + let summary = "SQL boolean type"; + let description = "Custom attr or value type in sql dialect"; + + // placeholder params + // let parameters = (ins StringRefParameter<"the custom value">:$value); + // let assemblyFormat = "`<` $value `>`"; +} + #endif // SQL_TYPES \ No newline at end of file diff --git a/lib/sql/Ops.cpp b/lib/sql/Ops.cpp index 96116ad193cf..bfc4c28200c1 100644 --- a/lib/sql/Ops.cpp +++ b/lib/sql/Ops.cpp @@ -212,4 +212,27 @@ class UnparsedOpInnerCast final : public OpRewritePattern { void UnparsedOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.insert >(context); +} + + +class SQLStringConcatOpCanonicalization final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SQLStringConcatOp op, + PatternRewriter &rewriter) const override { + + auto input1 = op->getOperand(0).getDefiningOp(); + auto input2 = op->getOperand(1).getDefiningOp(); + + if (!input1 || !input2) return failure(); + + rewriter.replaceOpWithNewOp(op, op.getType(), (input1.getInput() + input2.getInput()).str()); + return success(); + } +}; + +void SQLStringConcatOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert(context); } \ No newline at end of file diff --git a/lib/sql/Parser.cpp b/lib/sql/Parser.cpp index 62bc4973df84..6b089bd2b4b3 100644 --- a/lib/sql/Parser.cpp +++ b/lib/sql/Parser.cpp @@ -40,7 +40,6 @@ struct ParseValue { return ty; } Value getValue() const { - int type_value = static_cast(ty); // llvm::errs() << "ty: " << type_value << "\n"; // llvm::errs() << "value: " << value << "\n"; // llvm::errs() 
<< "attr: " << attr << "\n"; @@ -57,7 +56,8 @@ struct ParseValue { enum class ParseMode { None, Column, - Table + Table, + Bool, }; @@ -80,11 +80,6 @@ class SQLParser { SQLParser(Location loc, OpBuilder &builder, std::string sql, int i) : loc(loc), builder(builder), sql(sql), i(i) { - // llvm::errs() << "last three " << sql.substr(sql.size()-6, sql.size()) << "\n"; - // if (!sql.substr(sql.size()-2, sql.size()).compare("\\00")){ - // llvm::errs() << "triggers trim" << "\n"; - // sql = sql.substr(0, sql.size()-1); - // } } std::string peek() { @@ -179,52 +174,59 @@ class SQLParser { Value table = nullptr; while (true) { peekStr = peek(); - // if (hasColumns) { - if (peekStr == "FROM") { - pop(); - table = parseNext(ParseMode::Table).getValue(); - hasColumns = false; + if (hasColumns) { + if (peekStr == "FROM") { + pop(); + table = parseNext(ParseMode::Table).getValue(); + hasColumns = false; + break; + } + ParseValue col = parseNext(ParseMode::Column); + if (col.getType() == ParseType::Nothing) { + hasColumns = false; + break; + } else { + columns.push_back(col.getValue()); + } + if (peekStr == ",") pop(); + } else if (peekStr == "WHERE") { + pop(); + hasWhere = true; + } else { break; + // assert(0 && " additional clauses like limit/etc not yet handled"); } - ParseValue col = parseNext(ParseMode::Column); - if (col.getType() == ParseType::Nothing) { - hasColumns = false; - break; - } else { - columns.push_back(col.getValue()); - } - if (peekStr == ",") pop(); - // } else if (peekStr == "WHERE") { - // pop(); - // hasWhere = true; - // } else if (hasWhere){ - // // do something here - // break; - // } else { - // break; - // // assert(0 && " additional clauses like limit/etc not yet handled"); - // } } - if (table) - return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), columns, table).getResult()); - else - return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), columns).getResult()); + if (!table) + table = 
builder.create(loc, ExprType::get(builder.getContext()), builder.getStringAttr("")).getResult(); + return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), columns, table).getResult()); } else if (is_number(&peekStr)){ pop(); return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), builder.getStringAttr(peekStr)).getResult()); } else if (mode == ParseMode::Column) { // do we need this?? - // if (peekStr == "*") { - // pop(); - - // return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), ); - // } + if (peekStr == "*") { + pop(); + return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), builder.getStringAttr(peekStr)).getResult()); + } pop(); return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), builder.getStringAttr(peekStr)).getResult()); } else if (mode == ParseMode::Table) { pop(); - return ParseValue(builder.create(loc,ExprType::get(builder.getContext()) , builder.getStringAttr(peekStr)).getResult()); - } + return ParseValue(builder.create(loc,ExprType::get(builder.getContext()), builder.getStringAttr(peekStr)).getResult()); + } else if (mode == ParseMode::Bool) { + // col = peekStr; + pop(); + + } else if (peekStr == "(") { + pop(); + ParseValue res = parseNext(ParseMode::None); + assert(peek() == ")"); + pop(); + return res; + } else if (peekStr == ")") { + return ParseValue(); + } llvm::errs() << " Unknown token to parse: " << peekStr << "\n"; llvm_unreachable("Unknown token to parse"); } @@ -232,7 +234,7 @@ class SQLParser { }; std::vector SQLParser::reservedWords = { - "(", ")", ">=", "<=", "!=", ",", "=", ">", "<", "SELECT", "INSERT INTO", "VALUES", "UPDATE", "DELETE FROM", "WHERE", "FROM", "SET", "AS" + "(", ")", ">=", "<=", "!=", ",", "=", ">", "<", "SELECT", "DISTINCT", "INSERT INTO", "VALUES", "UPDATE", "DELETE FROM", "WHERE", "FROM", "SET", "AS" }; diff --git a/lib/sql/Passes/SQLLower.cpp b/lib/sql/Passes/SQLLower.cpp index 9dc30f9f75ba..76343c169b48 100644 
--- a/lib/sql/Passes/SQLLower.cpp +++ b/lib/sql/Passes/SQLLower.cpp @@ -151,6 +151,108 @@ struct GetValueOpLowering : public OpRewritePattern { } }; + +struct ConstantStringOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(sql::SQLConstantStringOp op, + PatternRewriter &rewriter) const final { + auto module = op->getParentOfType(); + + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + + auto expr = (op.getInput() + "\0").str(); + auto name = "str" + std::to_string((long long int)(Operation *)op); + auto MT = MemRefType::get({expr.size()}, rewriter.getI8Type()); + auto getglob = rewriter.create(op.getLoc(), MT, name); + + rewriter.setInsertionPointToStart(module.getBody()); + auto res = rewriter.create(op.getLoc(), rewriter.getStringAttr(name), + mlir::StringAttr(), mlir::TypeAttr::get(MT), rewriter.getStringAttr(expr), mlir::UnitAttr(), /*alignment*/nullptr); + rewriter.replaceOpWithNewOp(op, MemRefType::get({-1}, rewriter.getI8Type()), getglob.getResult()); + return success(); + } +}; + +// struct StringConcatOpLowering : public OpRewritePattern { +// using OpRewritePattern::OpRewritePattern; + +// LogicalResult matchAndRewrite(sql::SQLStringConcatOp op, +// PatternRewriter &rewriter) const final { +// auto module = op->getParentOfType(); + +// SymbolTableCollection symbolTable; +// symbolTable.getSymbolTable(module); + +// return success(); +// } +// }; + + +struct ToStringOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(sql::SQLToStringOp op, + PatternRewriter &rewriter) const final { + auto module = op->getParentOfType(); + + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + + // 2) convert the args to valid args to postgres_getresult abi + Value expr = op.getExpr(); + auto definingOp = expr.getDefiningOp(); + if (auto selectOp = dyn_cast(definingOp)){ + Value current = 
rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), "SELECT "); + bool prevColumn = false; + for (auto v : selectOp.getColumns()) { + Value columns = rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), v); + Value args[] = { current, columns }; + current = rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()),args); + if (prevColumn) { + Value args[] = { current, rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), ", ") }; + current = rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()),args); + } + prevColumn = true; + } + auto tableOp = selectOp.getTable().getDefiningOp(); + if (!tableOp || !tableOp.getExpr().empty()) { + Value args[] = { current, rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), "FROM ") }; + current = rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()),args); + Value args2[] = { current, rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), selectOp.getTable()) }; + current = rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()),args2); + } + rewriter.replaceOp(op, current); + } else if (auto tabOp = dyn_cast(definingOp)) { + Value res = rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), tabOp.getExpr()); + rewriter.replaceOp(op, res); + } else if (auto colOp = dyn_cast(definingOp)) { + Value res = rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), colOp.getExpr()); + rewriter.replaceOp(op, res); + } else if (auto intOp = dyn_cast(definingOp)){ + Value res = rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), intOp.getExpr()); + rewriter.replaceOp(op, res); + } else { + assert(0 && "unknown type to convert to string"); + } + + return success(); + } +}; + struct ExecuteOpLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -175,17 +277,16 @@ 
struct ExecuteOpLowering : public OpRewritePattern { executefn.getFunctionType().getInput(0), conn); Value command = op.getCommand(); - command = rewriter.create(op.getLoc(), - rewriter.getI8Type(), command); - command = rewriter.create( - op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), command); - StringRef strname = command.getDefiningOp().getGlobalName(); - Attribute strattr = dyn_cast_or_null( - symbolTable.lookupSymbolIn(module, rewriter.getStringAttr(strname))).getValueAttr(); - auto str = strattr.cast().getValue(); - llvm::errs() << str << "\n"; - command = rewriter.create(op.getLoc(), - executefn.getFunctionType().getInput(1), command); + // auto name = "str" + std::to_string((long long int)(Operation *)command.getDefiningOp()); + command = rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), command); + // auto type = MemRefType::get({-1}, rewriter.getI8Type()); + + + // auto globalOp = rewriter.create( + // op.getLoc(), type, /* isConstant */ true, LLVM::Linkage::Internal, name, + // mlir::Attribute(), + // /* alignment */ 0, /* addrSpace */ 0); // 3) call and replace @@ -262,6 +363,14 @@ void SQLLower::runOnOperation() { patterns.insert(&getContext()); patterns.insert(&getContext()); patterns.insert(&getContext()); + patterns.insert(&getContext()); + patterns.insert(&getContext()); + + + for (auto *dialect : getContext().getLoadedDialects()) + dialect->getCanonicalizationPatterns(patterns); + for (RegisteredOperationName op : getContext().getRegisteredOperations()) + op.getCanonicalizationPatterns(patterns, &getContext()); GreedyRewriteConfig config; (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), From 5249dbe8c99bb3d875ec71d6d17e617bf930618a Mon Sep 17 00:00:00 2001 From: Carl Guo Date: Thu, 26 Oct 2023 14:03:03 -0400 Subject: [PATCH 13/15] basic select --- include/sql/SQLOps.td | 35 +++++++++++++------ include/sql/SQLTypes.td | 1 - lib/sql/Parser.cpp | 51 ++++++++++++++++++++++------ 
lib/sql/Passes/SQLLower.cpp | 68 +++++++++++++++++++++++++++---------- 4 files changed, 116 insertions(+), 39 deletions(-) diff --git a/include/sql/SQLOps.td b/include/sql/SQLOps.td index 94c4dc0a67f1..2f49e0a9cffa 100644 --- a/include/sql/SQLOps.td +++ b/include/sql/SQLOps.td @@ -38,16 +38,6 @@ def ColumnOp : SQL_Op<"column", [Pure]> { let hasCanonicalizer = 0; } -def AllColumnsOp : SQL_Op<"all_columns", [Pure]> { - let summary = "all columns op"; - - let arguments = (ins StrAttr:$expr); - let results = (outs SQLExprType:$result); - - let hasFolder = 0; - let hasCanonicalizer = 0; -} - def WhereOp: SQL_Op<"where", [Pure]> { let summary = "where op"; @@ -130,16 +120,39 @@ def SQLStringConcatOp : SQL_Op<"string_concat", [Pure]> { let hasCanonicalizer = 1; } +def ConstantBoolOp : SQL_Op <"constant_bool", [Pure]> { + let summary = "constant_bool"; + let results = (outs SQLBoolType:$result); + +} + + def SelectOp : SQL_Op<"select", [Pure]> { let summary = "select"; // i need to specify the size of a Variadic? 
- let arguments = (ins Variadic:$columns, SQLExprType:$table); + let arguments = (ins Variadic:$columns, + SQLExprType:$table, + // SQLBoolType:$where, + BoolAttr:$selectAll + IntAttr:$limit); + // attribute limit if >= 0 then its the real thing, otherwise its infinity + let results = (outs SQLExprType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +def AllColumnsOp : SQL_Op<"all_columns", [Pure]> { + let summary = "all_columns"; let results = (outs SQLExprType:$result); let hasFolder = 0; let hasCanonicalizer = 0; } + + + def UnparsedOp : SQL_Op<"unparsed", [Pure]> { let summary = "unparsed sql op"; diff --git a/include/sql/SQLTypes.td b/include/sql/SQLTypes.td index 8d0800a62c5e..a8fc3f0e30d6 100644 --- a/include/sql/SQLTypes.td +++ b/include/sql/SQLTypes.td @@ -28,7 +28,6 @@ def SQLExprType : SQL_Type<"Expr", "expr"> { // let assemblyFormat = "`<` $value `>`"; } - def SQLBoolType : SQL_Type<"Bool", "bool"> { let summary = "SQL boolean type"; let description = "Custom attr or value type in sql dialect"; diff --git a/lib/sql/Parser.cpp b/lib/sql/Parser.cpp index 6b089bd2b4b3..25a4776368e2 100644 --- a/lib/sql/Parser.cpp +++ b/lib/sql/Parser.cpp @@ -106,7 +106,7 @@ class SQLParser { return {"", 0}; } for (std::string rWord : reservedWords) { - auto token = sql.substr(i, std::min(sql.size(), i + rWord.size())); + auto token = sql.substr(i, std::min(sql.size()-i, rWord.size())); std::transform(token.begin(), token.end(), token.begin(), ::toupper); if (token == rWord) { return {token, static_cast(token.size())}; @@ -148,12 +148,14 @@ class SQLParser { } - // Parse the next command, if any ParseValue parseNext(ParseMode mode) { + // for (unsigned int j = i; j < sql.size(); j++) { + // auto peekStr = peek(); + // pop(); + // llvm::errs() << "peekStrTest: " << i << " " << peekStr << "\n"; + // } if (i >= sql.size()) { - llvm::errs() << "here i:" << i << "\n"; - llvm::errs() << "here size:" << sql.size() << "\n"; return ParseValue(); } auto peekStr = 
peek(); @@ -171,6 +173,8 @@ class SQLParser { llvm::SmallVector columns; bool hasColumns = true; bool hasWhere = false; + bool selectAll = false; + int limit = -1; Value table = nullptr; while (true) { peekStr = peek(); @@ -178,9 +182,20 @@ class SQLParser { if (peekStr == "FROM") { pop(); table = parseNext(ParseMode::Table).getValue(); + llvm::errs() << "table: " << table << "\n"; hasColumns = false; break; } + if (peekStr == "*") { + pop(); + selectAll = true; + continue; + } + if (peekStr == ",") { + pop(); + llvm::errs() << "comma\n"; + continue; + } ParseValue col = parseNext(ParseMode::Column); if (col.getType() == ParseType::Nothing) { hasColumns = false; @@ -188,18 +203,34 @@ class SQLParser { } else { columns.push_back(col.getValue()); } - if (peekStr == ",") pop(); + } else if (peekStr == "WHERE") { pop(); hasWhere = true; + } else if (peekStr == "LIMIT"){ + pop(); + peekStr = peek(); + if (peekStr == "ALL"){ + pop(); + } else if (is_number(&peekStr)){ + pop(); + limit = std::stoi(peekStr); + } } else { - break; - // assert(0 && " additional clauses like limit/etc not yet handled"); + // break; + assert(0 && " additional clauses like where/etc not yet handled"); } } - if (!table) + if (!table){ + llvm::errs() << " table is null: " << table << "\n"; table = builder.create(loc, ExprType::get(builder.getContext()), builder.getStringAttr("")).getResult(); - return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), columns, table).getResult()); + } + // if (selectAll){ + // assert(table && "table cannot be null"); + // return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), table).getResult()); + // } else { + return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), columns, table, selectALl, limit).getResult()); + // } } else if (is_number(&peekStr)){ pop(); return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), builder.getStringAttr(peekStr)).getResult()); @@ -234,7 +265,7 @@ class 
SQLParser { }; std::vector SQLParser::reservedWords = { - "(", ")", ">=", "<=", "!=", ",", "=", ">", "<", "SELECT", "DISTINCT", "INSERT INTO", "VALUES", "UPDATE", "DELETE FROM", "WHERE", "FROM", "SET", "AS" + "(", ")", ">=", "<=", "!=", ",", "=", ">", "<", ",", "SELECT", "DISTINCT", "INSERT INTO", "VALUES", "UPDATE", "DELETE FROM", "WHERE", "FROM", "SET", "AS" }; diff --git a/lib/sql/Passes/SQLLower.cpp b/lib/sql/Passes/SQLLower.cpp index 76343c169b48..eaa26ddb2458 100644 --- a/lib/sql/Passes/SQLLower.cpp +++ b/lib/sql/Passes/SQLLower.cpp @@ -161,16 +161,26 @@ struct ConstantStringOpLowering : public OpRewritePattern(u)) return failure(); + } + auto expr = op.getInput().str(); auto name = "str" + std::to_string((long long int)(Operation *)op); - auto MT = MemRefType::get({expr.size()}, rewriter.getI8Type()); + auto MT = MemRefType::get({expr.size() + 1}, rewriter.getI8Type()); + // auto type = MemRefType::get(mt.getShape(), mt.getElementType(), {}); auto getglob = rewriter.create(op.getLoc(), MT, name); + + SmallVector data(expr.begin(), expr.end()); + data.push_back('\0'); + auto attr = DenseElementsAttr::get( + RankedTensorType::get(MT.getShape(), MT.getElementType()), data); - rewriter.setInsertionPointToStart(module.getBody()); - auto res = rewriter.create(op.getLoc(), rewriter.getStringAttr(name), - mlir::StringAttr(), mlir::TypeAttr::get(MT), rewriter.getStringAttr(expr), mlir::UnitAttr(), /*alignment*/nullptr); + auto loc = op.getLoc(); rewriter.replaceOpWithNewOp(op, MemRefType::get({-1}, rewriter.getI8Type()), getglob.getResult()); + rewriter.setInsertionPointToStart(module.getBody()); + auto res = rewriter.create(loc, rewriter.getStringAttr(name), + mlir::StringAttr(), mlir::TypeAttr::get(MT), attr, rewriter.getUnitAttr(), /*alignment*/nullptr); + return success(); } }; @@ -207,22 +217,27 @@ struct ToStringOpLowering : public OpRewritePattern { Value current = rewriter.create(op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), "SELECT "); bool 
prevColumn = false; - for (auto v : selectOp.getColumns()) { - Value columns = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), v); - Value args[] = { current, columns }; - current = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()),args); + auto columns = selectOp.getColumns(); + for (mlir::Value v : columns) { if (prevColumn) { - Value args[] = { current, rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), ", ") }; - current = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()),args); + Value args[] = { current, rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), ", ") }; + current = rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()),args); } + Value col = rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), v); + Value args[] = { col, rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), " ")}; + col = rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), args); + Value args2[] = { current, col }; + current = rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), args2); prevColumn = true; } auto tableOp = selectOp.getTable().getDefiningOp(); - if (!tableOp || !tableOp.getExpr().empty()) { + if (tableOp) { Value args[] = { current, rewriter.create(op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), "FROM ") }; current = rewriter.create(op.getLoc(), @@ -233,6 +248,16 @@ struct ToStringOpLowering : public OpRewritePattern { MemRefType::get({-1}, rewriter.getI8Type()),args2); } rewriter.replaceOp(op, current); + } else if (auto selectAllOp = dyn_cast(definingOp)){ + auto table = rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), selectAllOp.getTable()); + Value res = rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), "SELECT * FROM "); + Value args[] = { res, table }; + res = 
rewriter.create(op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()),args); + + rewriter.replaceOp(op, res); } else if (auto tabOp = dyn_cast(definingOp)) { Value res = rewriter.create(op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), tabOp.getExpr()); @@ -244,6 +269,7 @@ struct ToStringOpLowering : public OpRewritePattern { } else if (auto intOp = dyn_cast(definingOp)){ Value res = rewriter.create(op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), intOp.getExpr()); + llvm::errs() << "intOp: " << intOp.getExpr() << "\n"; rewriter.replaceOp(op, res); } else { assert(0 && "unknown type to convert to string"); @@ -280,6 +306,8 @@ struct ExecuteOpLowering : public OpRewritePattern { // auto name = "str" + std::to_string((long long int)(Operation *)command.getDefiningOp()); command = rewriter.create(op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), command); + llvm::errs() << "command: " << command << "\n"; + llvm::errs() << "command type: " << command.getType() << "\n"; // auto type = MemRefType::get({-1}, rewriter.getI8Type()); @@ -295,6 +323,12 @@ struct ExecuteOpLowering : public OpRewritePattern { Value res = rewriter.create(op.getLoc(), executefn, args) ->getResult(0); + res = rewriter.create(op.getLoc(), + LLVM::LLVMPointerType::get(rewriter.getI8Type()), res); + res = rewriter.create( + op.getLoc(), rewriter.getI64Type(), res); + res = rewriter.create(op.getLoc(), + op.getType(), res); rewriter.replaceOp(op, res); From 0a26185c68651ccac6b6f0185ca327044b28f7df Mon Sep 17 00:00:00 2001 From: Carl Guo Date: Tue, 14 Nov 2023 13:24:39 -0500 Subject: [PATCH 14/15] merge this before using new parser --- include/sql/SQLOps.td | 22 +- lib/sql/Ops.cpp | 183 +++++++++----- lib/sql/Parser.cpp | 481 ++++++++++++++++++++---------------- lib/sql/Passes/SQLLower.cpp | 449 ++++++++++++++++++++------------- 4 files changed, 675 insertions(+), 460 deletions(-) diff --git a/include/sql/SQLOps.td b/include/sql/SQLOps.td index 2f49e0a9cffa..eafeaac1c5d5 
100644 --- a/include/sql/SQLOps.td +++ b/include/sql/SQLOps.td @@ -14,6 +14,7 @@ include "SQLDialect.td" include "SQLTypes.td" + class SQL_Op traits = []> : Op; @@ -61,7 +62,7 @@ def CalcBoolOp: SQL_Op<"calc_bool", [Pure]> { def AndOp: SQL_Op<"and", [Pure]> { let summary = "and op"; - let arguments = (ins Variadic:$expr); + let arguments = (ins SQLBoolType:$left, SQLBoolType:$right); let results = (outs SQLBoolType:$result); let hasFolder = 0; @@ -71,7 +72,7 @@ def AndOp: SQL_Op<"and", [Pure]> { def OrOp: SQL_Op<"or", [Pure]> { let summary = "or op"; - let arguments = (ins Variadic:$expr); + let arguments = (ins SQLBoolType:$left, SQLBoolType:$right); let results = (outs SQLBoolType:$result); let hasFolder = 0; @@ -110,6 +111,17 @@ def SQLToStringOp : SQL_Op<"to_string", [Pure]> { } +def SQLBoolToStringOp : SQL_Op<"bool_to_string", [Pure]> { + let summary = "bool_to_string"; + + let arguments = (ins SQLBoolType:$expr); + let results = (outs AnyType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + + def SQLStringConcatOp : SQL_Op<"string_concat", [Pure]> { let summary = "string_concat"; @@ -123,7 +135,6 @@ def SQLStringConcatOp : SQL_Op<"string_concat", [Pure]> { def ConstantBoolOp : SQL_Op <"constant_bool", [Pure]> { let summary = "constant_bool"; let results = (outs SQLBoolType:$result); - } @@ -132,9 +143,8 @@ def SelectOp : SQL_Op<"select", [Pure]> { // i need to specify the size of a Variadic? 
let arguments = (ins Variadic:$columns, SQLExprType:$table, - // SQLBoolType:$where, - BoolAttr:$selectAll - IntAttr:$limit); + SQLExprType:$where, + SI64Attr:$limit); // attribute limit if >= 0 then its the real thing, otherwise its infinity let results = (outs SQLExprType:$result); diff --git a/lib/sql/Ops.cpp b/lib/sql/Ops.cpp index bfc4c28200c1..4f2c2b34cddb 100644 --- a/lib/sql/Ops.cpp +++ b/lib/sql/Ops.cpp @@ -5,10 +5,10 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +#include +#include #include #include -#include -#include #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -18,11 +18,11 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "polygeist/Ops.h" +#include "sql/Parser.h" #include "sql/SQLDialect.h" #include "sql/SQLOps.h" #include "sql/SQLTypes.h" -#include "sql/Parser.h" -#include "polygeist/Ops.h" #define GET_OP_CLASSES #include "sql/SQLOps.cpp.inc" @@ -41,13 +41,12 @@ #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" - -#include "mlir/IR/Value.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" -#include "mlir/IR/Attributes.h" +#include "mlir/IR/Value.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/BuiltinTypes.h" #define DEBUG_TYPE "sql" @@ -55,7 +54,6 @@ using namespace mlir; using namespace sql; using namespace mlir::arith; - class GetValueOpTypeFix final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -67,38 +65,38 @@ class GetValueOpTypeFix final : public OpRewritePattern { Value handle = op.getOperand(0); if (!handle.getType().isa()) { - handle = rewriter.create(op.getLoc(), - rewriter.getIndexType(), handle); - changed = true; + handle = rewriter.create(op.getLoc(), + rewriter.getIndexType(), handle); + changed = true; } 
Value row = op.getOperand(1); if (!row.getType().isa()) { - row = rewriter.create(op.getLoc(), - rewriter.getIndexType(), row); - changed = true; + row = rewriter.create(op.getLoc(), rewriter.getIndexType(), + row); + changed = true; } Value column = op.getOperand(2); if (!column.getType().isa()) { - column = rewriter.create(op.getLoc(), - rewriter.getIndexType(), column); - changed = true; + column = rewriter.create(op.getLoc(), + rewriter.getIndexType(), column); + changed = true; } - if (!changed) return failure(); + if (!changed) + return failure(); - rewriter.replaceOpWithNewOp(op, op.getType(), handle, row, column); + rewriter.replaceOpWithNewOp(op, op.getType(), handle, row, + column); return success(changed); } }; void GetValueOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { + MLIRContext *context) { results.insert(context); } - - class NumResultsOpTypeFix final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -108,21 +106,24 @@ class NumResultsOpTypeFix final : public OpRewritePattern { bool changed = false; Value handle = op->getOperand(0); - if (handle.getType().isa() && op->getResultTypes()[0].isa()) - return failure(); + if (handle.getType().isa() && + op->getResultTypes()[0].isa()) + return failure(); if (!handle.getType().isa()) { - handle = rewriter.create(op.getLoc(), - rewriter.getIndexType(), handle); - changed = true; + handle = rewriter.create(op.getLoc(), + rewriter.getIndexType(), handle); + changed = true; } - mlir::Value res = rewriter.create(op.getLoc(), rewriter.getIndexType(), handle); + mlir::Value res = rewriter.create( + op.getLoc(), rewriter.getIndexType(), handle); if (op->getResultTypes()[0].isa()) { - rewriter.replaceOp(op, res); + rewriter.replaceOp(op, res); } else { - rewriter.replaceOpWithNewOp(op, op->getResultTypes()[0], res); + rewriter.replaceOpWithNewOp(op, op->getResultTypes()[0], + res); } return success(changed); @@ -130,12 +131,10 @@ class 
NumResultsOpTypeFix final : public OpRewritePattern { }; void NumResultsOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { + MLIRContext *context) { results.insert(context); } - - // class ExecuteOpTypeFix final : public OpRewritePattern { // public: // using OpRewritePattern::OpRewritePattern; @@ -147,39 +146,44 @@ void NumResultsOp::getCanonicalizationPatterns(RewritePatternSet &results, // Value conn = op->getOperand(0); // Value command = op->getOperand(1); -// if (conn.getType().isa() && command.getType().isa() && op->getResultTypes()[0].isa()) +// if (conn.getType().isa() && command.getType().isa() +// && op->getResultTypes()[0].isa()) // return failure(); // if (!conn.getType().isa()) { // conn = rewriter.create(op.getLoc(), -// rewriter.getIndexType(), conn); +// rewriter.getIndexType(), +// conn); // changed = true; // } // if (command.getType().isa()) { -// command = rewriter.create(op.getLoc(), -// LLVM::LLVMPointerType::get(rewriter.getI8Type()), command); +// command = rewriter.create(op.getLoc(), +// LLVM::LLVMPointerType::get(rewriter.getI8Type()), +// command); // changed = true; // } - // if (command.getType().isa()) { -// command = rewriter.create(op.getLoc(), -// rewriter.getI64Type(), command); +// command = rewriter.create(op.getLoc(), +// rewriter.getI64Type(), +// command); // changed = true; // } // if (!command.getType().isa()) { -// command = rewriter.create(op.getLoc(), -// rewriter.getIndexType(), command); +// command = rewriter.create(op.getLoc(), +// rewriter.getIndexType(), +// command); // changed = true; // } // if (!changed) return failure(); -// mlir::Value res = rewriter.create(op.getLoc(), rewriter.getIndexType(), conn, command); -// rewriter.replaceOp(op, res); +// mlir::Value res = rewriter.create(op.getLoc(), +// rewriter.getIndexType(), conn, command); rewriter.replaceOp(op, res); // // if (op->getResultTypes()[0].isa()) { // // rewriter.replaceOp(op, res); // // } else { -// // 
rewriter.replaceOpWithNewOp(op, op->getResultTypes()[0], res); +// // rewriter.replaceOpWithNewOp(op, +// op->getResultTypes()[0], res); // // } // return success(changed); // } @@ -190,8 +194,7 @@ void NumResultsOp::getCanonicalizationPatterns(RewritePatternSet &results, // results.insert(context); // } - -template +template class UnparsedOpInnerCast final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -200,9 +203,10 @@ class UnparsedOpInnerCast final : public OpRewritePattern { PatternRewriter &rewriter) const override { Value input = op->getOperand(0); - + auto cst = input.getDefiningOp(); - if (!cst) return failure(); + if (!cst) + return failure(); rewriter.replaceOpWithNewOp(op, op.getType(), cst.getOperand()); return success(); @@ -210,29 +214,80 @@ class UnparsedOpInnerCast final : public OpRewritePattern { }; void UnparsedOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.insert >(context); + MLIRContext *context) { + results.insert>(context); } - -class SQLStringConcatOpCanonicalization final : public OpRewritePattern { +class SQLStringConcatOpCanonicalization final + : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SQLStringConcatOp op, PatternRewriter &rewriter) const override { - - auto input1 = op->getOperand(0).getDefiningOp(); - auto input2 = op->getOperand(1).getDefiningOp(); - - if (!input1 || !input2) return failure(); - - rewriter.replaceOpWithNewOp(op, op.getType(), (input1.getInput() + input2.getInput()).str()); - return success(); + // Whether we changed the state. 
If we make no simplifications we need to + // return failure otherwise we will infinite loop + bool changed = false; + // Operands to the simplified concat + SmallVector operands; + // Constants that we will merge, "current running constant" + SmallVector constants; + for (auto op : op->getOperands()) { + if (auto constOp = op.getDefiningOp()) { + constants.push_back(constOp); + continue; + } + if (constants.size() != 0) { + if (constants.size() == 1) { + operands.push_back(constants[0]); + } else { + std::string nextStr; + changed = true; + for (auto str : constants) + nextStr += str.getInput().str(); + + operands.push_back(rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), nextStr)); + } + } + constants.clear(); + if (auto concat = op.getDefiningOp()) { + changed = true; + for (auto op2 : concat->getOperands()) + operands.push_back(op2); + continue; + } + operands.push_back(op); + } + if (constants.size() != 0) { + if (constants.size() == 1) { + operands.push_back(constants[0]); + } else { + std::string nextStr; + changed = true; + for (auto str : constants) + nextStr = nextStr + str.getInput().str(); + operands.push_back(rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), nextStr)); + } + } + if (operands.size() == 0) { + rewriter.replaceOpWithNewOp(op, MemRefType::get({-1}, rewriter.getI8Type()), ""); + return success(); + } + if (operands.size() == 1) { + rewriter.replaceOp(op, operands[0]); + return success(); + } + if (changed) { + rewriter.replaceOpWithNewOp(op, MemRefType::get({-1}, rewriter.getI8Type()), operands); + return success(); + } + return failure(); } }; void SQLStringConcatOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { + MLIRContext *context) { results.insert(context); -} \ No newline at end of file +} diff --git a/lib/sql/Parser.cpp b/lib/sql/Parser.cpp index 25a4776368e2..86125e1b7fc1 100644 --- a/lib/sql/Parser.cpp +++ b/lib/sql/Parser.cpp @@ -1,277 
+1,320 @@ +#include +#include #include #include -#include -#include -#include "mlir/IR/Value.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" -#include "mlir/IR/Attributes.h" +#include "mlir/IR/Value.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/BuiltinTypes.h" - #include "sql/SQLDialect.h" #include "sql/SQLOps.h" #include "sql/SQLTypes.h" - using namespace mlir; using namespace sql; enum class ParseType { - Nothing = 0, - Value = 1, - Attribute = 2, + Nothing = 0, + Value = 1, + Attribute = 2, }; struct ParseValue { private: - ParseType ty; - Value value; - Attribute attr; + ParseType ty; + Value value; + Attribute attr; + public: - ParseValue() : ty(ParseType::Nothing), value(nullptr), attr(nullptr) {} - ParseValue(Value value) : ty(ParseType::Value), value(value), attr(nullptr) { - assert(value); - } - ParseValue(Attribute attr) : ty(ParseType::Attribute), value(nullptr), attr(attr) {} + ParseValue() : ty(ParseType::Nothing), value(nullptr), attr(nullptr) {} + ParseValue(Value value) : ty(ParseType::Value), value(value), attr(nullptr) { + assert(value); + } + ParseValue(Attribute attr) + : ty(ParseType::Attribute), value(nullptr), attr(attr) {} - ParseType getType() const { - return ty; - } - Value getValue() const { - // llvm::errs() << "ty: " << type_value << "\n"; - // llvm::errs() << "value: " << value << "\n"; - // llvm::errs() << "attr: " << attr << "\n"; - assert(ty == ParseType::Value); - assert(value); - return value; - } - Attribute getAttr() const { - assert(ty == ParseType::Attribute); - return attr; - } + ParseType getType() const { return ty; } + Value getValue() const { + assert(ty == ParseType::Value); + assert(value); + return value; + } + Attribute getAttr() const { + assert(ty == ParseType::Attribute); + return attr; + } }; enum class ParseMode { - None, - Column, - Table, - Bool, + None, + Column, + Table, + Bool, + Clause }; - -template -std::ostream& 
operator<<(typename std::enable_if::value, std::ostream>::type& stream, const T& e) -{ - return stream << static_cast::type>(e); +template +std::ostream &operator<<( + typename std::enable_if::value, std::ostream>::type &stream, + const T &e) { + return stream << static_cast::type>(e); } class SQLParser { public: - Location loc; - OpBuilder &builder; - std::string sql; - unsigned int i; + Location loc; + OpBuilder &builder; + std::string sql; + unsigned int i; - static std::vector reservedWords; + static std::vector reservedWords; - - SQLParser(Location loc, OpBuilder &builder, std::string sql, int i) : loc(loc), builder(builder), - sql(sql), i(i) { - } + SQLParser(Location loc, OpBuilder &builder, std::string sql, int i) + : loc(loc), builder(builder), sql(sql), i(i) {} - std::string peek() { - auto [peeked, _] = peekWithLength(); - return peeked; - } + std::string peek() { + auto [peeked, _] = peekWithLength(); + return peeked; + } - std::string pop() { - auto [peeked, len] = peekWithLength(); - i += len; - popWhitespace(); - return peeked; - } + std::string pop() { + auto [peeked, len] = peekWithLength(); + i += len; + popWhitespace(); + return peeked; + } - void popWhitespace() { - // it doesn't recognize - while (i < sql.size() && (sql[i] == ' ' || sql[i] == '\n' || sql[i] == '\t')) { - i++; - } + void popWhitespace() { + // it doesn't recognize + while (i < sql.size() && + (sql[i] == ' ' || sql[i] == '\n' || sql[i] == '\t')) { + i++; } + } - std::pair peekWithLength() { - if (i >= sql.size()) { - return {"", 0}; - } - for (std::string rWord : reservedWords) { - auto token = sql.substr(i, std::min(sql.size()-i, rWord.size())); - std::transform(token.begin(), token.end(), token.begin(), ::toupper); - if (token == rWord) { - return {token, static_cast(token.size())}; - } - } - if (sql[i] == '\'') { // Quoted string - return peekQuotedStringWithLength(); - } - return peekIdentifierWithLength(); + std::pair peekWithLength() { + if (i >= sql.size()) { + return {"", 
0}; } - - std::pair peekQuotedStringWithLength() { - if (sql.size() < i || sql[i] != '\'') { - return {"", 0}; - } - for (unsigned int j = i + 1; j < sql.size(); j++) { - if (sql[j] == '\'' && sql[j-1] != '\\') { - return {sql.substr(i + 1, j - (i + 1)), j - i + 2}; // +2 for the two quotes - } - } - return {"", 0}; + for (std::string rWord : reservedWords) { + auto token = sql.substr(i, std::min(sql.size() - i, rWord.size())); + std::transform(token.begin(), token.end(), token.begin(), ::toupper); + if (token == rWord) { + return {token, static_cast(token.size())}; + } } - - std::pair peekIdentifierWithLength() { - std::regex e("[a-zA-Z0-9_*]"); - for (unsigned int j = i; j < sql.size(); j++) { - if (!std::regex_match(std::string(1, sql[j]), e)) { - return {sql.substr(i, j - i), j - i}; - } - } - return {sql.substr(i), static_cast(sql.size())-i}; + if (sql[i] == '\'') { // Quoted string + return peekQuotedStringWithLength(); } + return peekIdentifierWithLength(); + } + std::pair peekQuotedStringWithLength() { + if (sql.size() < i || sql[i] != '\'') { + return {"", 0}; + } + for (unsigned int j = i + 1; j < sql.size(); j++) { + if (sql[j] == '\'' && sql[j - 1] != '\\') { + return {sql.substr(i + 1, j - (i + 1)), + j - i + 2}; // +2 for the two quotes + } + } + return {"", 0}; + } - bool is_number(std::string* s){ - std::string::const_iterator it = s->begin(); - while (it != s->end() && std::isdigit(*it)) ++it; - return !s->empty() && it == s->end(); + std::pair peekIdentifierWithLength() { + std::regex e("[a-zA-Z0-9_*]"); + for (unsigned int j = i; j < sql.size(); j++) { + if (!std::regex_match(std::string(1, sql[j]), e)) { + return {sql.substr(i, j - i), j - i}; + } } + return {sql.substr(i), static_cast(sql.size()) - i}; + } + bool is_number(std::string *s) { + std::string::const_iterator it = s->begin(); + while (it != s->end() && std::isdigit(*it)) + ++it; + return !s->empty() && it == s->end(); + } - // Parse the next command, if any - ParseValue 
parseNext(ParseMode mode) { - // for (unsigned int j = i; j < sql.size(); j++) { - // auto peekStr = peek(); - // pop(); - // llvm::errs() << "peekStrTest: " << i << " " << peekStr << "\n"; - // } - if (i >= sql.size()) { - return ParseValue(); - } - auto peekStr = peek(); - llvm::errs() << "peekStr: " << peekStr << "\n"; - assert(peekStr.size() > 0); + // Parse the next command, if any + ParseValue parseNext(ParseMode mode) { + // for (unsigned int j = i; j < sql.size(); j++) { + // auto peekStr = peek(); + // pop(); + // llvm::errs() << "peekStrTest: " << i << " " << peekStr << "\n"; + // } + if (i >= sql.size()) { + return ParseValue(); + } + auto peekStr = peek(); + llvm::errs() << "peekStr: " << peekStr << "\n"; + assert(peekStr.size() > 0); - if (peekStr == "SELECT") { - assert(mode == ParseMode::None || mode == ParseMode::Table); + if (peekStr == "SELECT") { + assert(mode == ParseMode::None || mode == ParseMode::Table); + pop(); + peekStr = peek(); + if (peekStr == "DISTINCT") { + pop(); + // do something different here + } + llvm::SmallVector columns; + bool hasColumns = true; + bool hasWhere = false; + int limit = -1; + Value table = nullptr; + Value where = nullptr; + while (true) { + peekStr = peek(); + if (peekStr == "") + break; + if (hasColumns) { + if (peekStr == "FROM") { pop(); - peekStr = peek(); - if (peekStr == "DISTINCT") { - pop(); - // do something different here - } - llvm::SmallVector columns; - bool hasColumns = true; - bool hasWhere = false; - bool selectAll = false; - int limit = -1; - Value table = nullptr; - while (true) { - peekStr = peek(); - if (hasColumns) { - if (peekStr == "FROM") { - pop(); - table = parseNext(ParseMode::Table).getValue(); - llvm::errs() << "table: " << table << "\n"; - hasColumns = false; - break; - } - if (peekStr == "*") { - pop(); - selectAll = true; - continue; - } - if (peekStr == ",") { - pop(); - llvm::errs() << "comma\n"; - continue; - } - ParseValue col = parseNext(ParseMode::Column); - if 
(col.getType() == ParseType::Nothing) { - hasColumns = false; - break; - } else { - columns.push_back(col.getValue()); - } - - } else if (peekStr == "WHERE") { - pop(); - hasWhere = true; - } else if (peekStr == "LIMIT"){ - pop(); - peekStr = peek(); - if (peekStr == "ALL"){ - pop(); - } else if (is_number(&peekStr)){ - pop(); - limit = std::stoi(peekStr); - } - } else { - // break; - assert(0 && " additional clauses like where/etc not yet handled"); - } - } - if (!table){ - llvm::errs() << " table is null: " << table << "\n"; - table = builder.create(loc, ExprType::get(builder.getContext()), builder.getStringAttr("")).getResult(); + table = parseNext(ParseMode::Table).getValue(); + hasColumns = false; + } else { + if (peekStr == ",") { + pop(); + continue; } - // if (selectAll){ - // assert(table && "table cannot be null"); - // return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), table).getResult()); - // } else { - return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), columns, table, selectALl, limit).getResult()); - // } - } else if (is_number(&peekStr)){ - pop(); - return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), builder.getStringAttr(peekStr)).getResult()); - } else if (mode == ParseMode::Column) { - // do we need this?? 
- if (peekStr == "*") { - pop(); - return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), builder.getStringAttr(peekStr)).getResult()); + ParseValue col = parseNext(ParseMode::Column); + if (col.getType() == ParseType::Nothing) { + hasColumns = false; + break; + } else { + columns.push_back(col.getValue()); } + } + } else if (peekStr == "WHERE") { + pop(); + ParseValue clause = parseNext(ParseMode::Bool); + if (clause.getType() == ParseType::Nothing) { + assert(0 && "where clause not recognized"); + } else { + where = builder.create(loc, ExprType::get(builder.getContext()), + clause.getValue()).getResult(); + } + } else if (peekStr == "LIMIT") { + pop(); + peekStr = peek(); + llvm::errs() << "limit: " << peekStr << "\n"; + if (peekStr == "ALL") { pop(); - return ParseValue(builder.create(loc, ExprType::get(builder.getContext()), builder.getStringAttr(peekStr)).getResult()); - } else if (mode == ParseMode::Table) { - pop(); - return ParseValue(builder.create(loc,ExprType::get(builder.getContext()), builder.getStringAttr(peekStr)).getResult()); - } else if (mode == ParseMode::Bool) { - // col = peekStr; - pop(); - - } else if (peekStr == "(") { + } else if (is_number(&peekStr)) { + llvm::errs() << "limit recognized: " + << "\n"; pop(); - ParseValue res = parseNext(ParseMode::None); - assert(peek() == ")"); - pop(); - return res; - } else if (peekStr == ")") { - return ParseValue(); - } - llvm::errs() << " Unknown token to parse: " << peekStr << "\n"; - llvm_unreachable("Unknown token to parse"); - } + limit = std::stoi(peekStr); + } else { + assert(0 && "not yet handled limit var"); + } + } else { + // break; + llvm::errs() << "peekstr that throws an error" << peekStr << "\n"; + assert(0 && " additional clauses like where/etc not yet handled"); + } + } + if (!table) { + llvm::errs() << " table is null: " << table << "\n"; + table = builder.create(loc, ExprType::get(builder.getContext()), + builder.getStringAttr("")).getResult(); + } + if (!where){ 
+ llvm::errs() << " where is null: " << table << "\n"; + Value clause = builder.create(loc, BoolType::get(builder.getContext()), + builder.getStringAttr("")).getResult(); + where = builder.create(loc, ExprType::get(builder.getContext()), + clause).getResult(); + } + return ParseValue( + builder.create(loc, ExprType::get(builder.getContext()), + columns, table, where, limit).getResult()); + } else if (is_number(&peekStr)) { + pop(); + return ParseValue(builder.create(loc, + ExprType::get(builder.getContext()), + builder.getStringAttr(peekStr)).getResult()); + } else if (mode == ParseMode::Column) { + if (peekStr == "*") { + pop(); + return ParseValue( + builder + .create(loc, ExprType::get(builder.getContext())) + .getResult()); + } + pop(); + return ParseValue( + builder + .create(loc, ExprType::get(builder.getContext()), + builder.getStringAttr(peekStr)) + .getResult()); + } else if (mode == ParseMode::Table) { + pop(); + return ParseValue( + builder + .create(loc, ExprType::get(builder.getContext()), + builder.getStringAttr(peekStr)) + .getResult()); + } else if (mode == ParseMode::Bool) { + // col = peekStr; + ParseValue left = parseNext(ParseMode::Clause); + peekStr = peek(); + if (peekStr == "AND") { + pop(); + ParseValue right = parseNext(ParseMode::Bool); + return ParseValue( + builder.create(loc, BoolType::get(builder.getContext()), left.getValue(), right.getValue()).getResult() + ); + } else if (peekStr == "OR") { + pop(); + ParseValue right = parseNext(ParseMode::Bool); + return ParseValue( + builder.create(loc, BoolType::get(builder.getContext()), left.getValue(), + right.getValue()).getResult() + ); + } else return left; + + } else if (mode == ParseMode::Clause){ + std::string clause = pop(); + clause += " " + pop(); + clause += " " + pop(); + return ParseValue( + builder.create(loc, BoolType::get(builder.getContext()), + builder.getStringAttr(clause)).getResult() + ); + } else if (peekStr == "(") { + pop(); + ParseValue res = 
parseNext(ParseMode::None); + assert(peek() == ")"); + pop(); + return res; + } else if (peekStr == ")") { + return ParseValue(); + } + llvm::errs() << " Unknown token to parse: " << peekStr << "\n"; + llvm_unreachable("Unknown token to parse"); + } }; std::vector SQLParser::reservedWords = { - "(", ")", ">=", "<=", "!=", ",", "=", ">", "<", ",", "SELECT", "DISTINCT", "INSERT INTO", "VALUES", "UPDATE", "DELETE FROM", "WHERE", "FROM", "SET", "AS" -}; + "(", ")", ">=", "<=", "!=", + ",", "=", ">", "<", ",", + "SELECT", "DISTINCT", "INSERT INTO", "VALUES", "UPDATE", + "DELETE FROM", "WHERE", "FROM", "SET", "AS"}; +mlir::Value parseSQL(mlir::Location loc, mlir::OpBuilder &builder, + std::string str) { + SQLParser parser(loc, builder, std::string(str), 0); + auto resOp = parser.parseNext(ParseMode::None); -mlir::Value parseSQL(mlir::Location loc, mlir::OpBuilder& builder, std::string str) { - SQLParser parser(loc, builder, std::string(str), 0); - auto resOp = parser.parseNext(ParseMode::None); - - return resOp.getValue(); + return resOp.getValue(); } \ No newline at end of file diff --git a/lib/sql/Passes/SQLLower.cpp b/lib/sql/Passes/SQLLower.cpp index eaa26ddb2458..eb0483e4bb32 100644 --- a/lib/sql/Passes/SQLLower.cpp +++ b/lib/sql/Passes/SQLLower.cpp @@ -11,7 +11,6 @@ //===----------------------------------------------------------------------===// #include "PassDetails.h" -#include "polygeist/Ops.h" #include "mlir/Analysis/CallGraph.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Async/IR/Async.h" @@ -20,10 +19,11 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" +#include "polygeist/Ops.h" +#include "sql/Passes/Passes.h" +#include "sql/SQLOps.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" -#include "sql/SQLOps.h" -#include "sql/Passes/Passes.h" #include #include @@ -41,7 +41,6 @@ struct SQLLower : public SQLLowerBase { } // end anonymous 
namespace - struct NumResultsOpLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -53,35 +52,33 @@ struct NumResultsOpLowering : public OpRewritePattern { symbolTable.getSymbolTable(module); // 1) make sure the postgres_getresult function is declared - auto rowsfn = dyn_cast_or_null( - symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQntuples"))); + auto rowsfn = dyn_cast_or_null(symbolTable.lookupSymbolIn( + module, rewriter.getStringAttr("PQntuples"))); // 2) convert the args to valid args to postgres_getresult abi Value arg = op.getHandle(); - arg = rewriter.create(op.getLoc(), - rewriter.getI8Type(), arg); + arg = rewriter.create(op.getLoc(), rewriter.getI8Type(), + arg); arg = rewriter.create( op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), arg); - - arg = rewriter.create(op.getLoc(), - rowsfn.getFunctionType().getInput(0), arg); + + arg = rewriter.create( + op.getLoc(), rowsfn.getFunctionType().getInput(0), arg); // 3) call and replace - Value args[] = {arg}; - - Value res = - rewriter.create(op.getLoc(), rowsfn, args) - ->getResult(0); + Value args[] = {arg}; + + Value res = rewriter.create(op.getLoc(), rowsfn, args) + ->getResult(0); - rewriter.replaceOpWithNewOp( - op, rewriter.getIndexType(), res); + rewriter.replaceOpWithNewOp(op, rewriter.getIndexType(), + res); return success(); } }; - struct GetValueOpLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -93,8 +90,8 @@ struct GetValueOpLowering : public OpRewritePattern { symbolTable.getSymbolTable(module); // 1) make sure the postgres_getresult function is declared - auto valuefn = dyn_cast_or_null( - symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQgetvalue"))); + auto valuefn = dyn_cast_or_null(symbolTable.lookupSymbolIn( + module, rewriter.getStringAttr("PQgetvalue"))); auto atoifn = dyn_cast_or_null( symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("atoi"))); @@ -102,48 +99,48 @@ struct 
GetValueOpLowering : public OpRewritePattern { // 2) convert the args to valid args to postgres_getresult abi Value handle = op.getHandle(); handle = rewriter.create(op.getLoc(), - rewriter.getI64Type(), handle); + rewriter.getI64Type(), handle); handle = rewriter.create( op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), handle); - handle = rewriter.create(op.getLoc(), - valuefn.getFunctionType().getInput(0), handle); + handle = rewriter.create( + op.getLoc(), valuefn.getFunctionType().getInput(0), handle); Value row = op.getRow(); - row = rewriter.create(op.getLoc(), - valuefn.getFunctionType().getInput(1), row); - Value column = op.getColumn(); - column = rewriter.create(op.getLoc(), - valuefn.getFunctionType().getInput(2), column); - - Value args[] = {handle, row, column}; - - Value res = - rewriter.create(op.getLoc(), valuefn, args) - ->getResult(0); + row = rewriter.create( + op.getLoc(), valuefn.getFunctionType().getInput(1), row); + Value column = op.getColumn(); + column = rewriter.create( + op.getLoc(), valuefn.getFunctionType().getInput(2), column); + + Value args[] = {handle, row, column}; + + Value res = rewriter.create(op.getLoc(), valuefn, args) + ->getResult(0); + + Value args2[] = {res}; - Value args2[] = {res}; - Value res2 = rewriter.create(op.getLoc(), atoifn, args2) ->getResult(0); if (op.getType() != res2.getType()) { - if (op.getType().isa()) - res2 = rewriter.create(op.getLoc(), - op.getType(), res2); - else if (auto IT = op.getType().dyn_cast()) { - auto IT2 = res2.getType().dyn_cast(); - if (IT.getWidth() < IT2.getWidth()) { - res2 = rewriter.create(op.getLoc(), - op.getType(), res2); - } else if (IT.getWidth() > IT2.getWidth()) { - res2 = rewriter.create(op.getLoc(), - op.getType(), res2); - } else assert(0 && "illegal integer type conversion"); - } else { - assert(0 && "illegal type conversion"); - } + if (op.getType().isa()) + res2 = rewriter.create(op.getLoc(), op.getType(), + res2); + else if (auto IT = 
op.getType().dyn_cast()) { + auto IT2 = res2.getType().dyn_cast(); + if (IT.getWidth() < IT2.getWidth()) { + res2 = + rewriter.create(op.getLoc(), op.getType(), res2); + } else if (IT.getWidth() > IT2.getWidth()) { + res2 = + rewriter.create(op.getLoc(), op.getType(), res2); + } else + assert(0 && "illegal integer type conversion"); + } else { + assert(0 && "illegal type conversion"); + } } rewriter.replaceOp(op, res2); @@ -151,20 +148,21 @@ struct GetValueOpLowering : public OpRewritePattern { } }; +struct ConstantStringOpLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; -struct ConstantStringOpLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(sql::SQLConstantStringOp op, + LogicalResult matchAndRewrite(sql::SQLConstantStringOp op, PatternRewriter &rewriter) const final { auto module = op->getParentOfType(); SymbolTableCollection symbolTable; symbolTable.getSymbolTable(module); - for (auto u: op.getResult().getUsers()){ - if (isa(u)) return failure(); + for (auto u : op.getResult().getUsers()) { + if (isa(u)) + return failure(); } - auto expr = op.getInput().str(); + auto expr = op.getInput().str(); auto name = "str" + std::to_string((long long int)(Operation *)op); auto MT = MemRefType::get({expr.size() + 1}, rewriter.getI8Type()); // auto type = MemRefType::get(mt.getShape(), mt.getElementType(), {}); @@ -173,37 +171,93 @@ struct ConstantStringOpLowering : public OpRewritePattern data(expr.begin(), expr.end()); data.push_back('\0'); auto attr = DenseElementsAttr::get( - RankedTensorType::get(MT.getShape(), MT.getElementType()), data); - + RankedTensorType::get(MT.getShape(), MT.getElementType()), data); + auto loc = op.getLoc(); - rewriter.replaceOpWithNewOp(op, MemRefType::get({-1}, rewriter.getI8Type()), getglob.getResult()); + rewriter.replaceOpWithNewOp( + op, MemRefType::get({-1}, rewriter.getI8Type()), getglob.getResult()); 
rewriter.setInsertionPointToStart(module.getBody()); - auto res = rewriter.create(loc, rewriter.getStringAttr(name), - mlir::StringAttr(), mlir::TypeAttr::get(MT), attr, rewriter.getUnitAttr(), /*alignment*/nullptr); + auto res = rewriter.create( + loc, rewriter.getStringAttr(name), mlir::StringAttr(), + mlir::TypeAttr::get(MT), attr, rewriter.getUnitAttr(), + /*alignment*/ nullptr); return success(); } }; -// struct StringConcatOpLowering : public OpRewritePattern { -// using OpRewritePattern::OpRewritePattern; -// LogicalResult matchAndRewrite(sql::SQLStringConcatOp op, -// PatternRewriter &rewriter) const final { -// auto module = op->getParentOfType(); +struct BoolToStringOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; -// SymbolTableCollection symbolTable; -// symbolTable.getSymbolTable(module); - -// return success(); -// } -// }; + LogicalResult matchAndRewrite(sql::SQLBoolToStringOp op, + PatternRewriter &rewriter) const final { + auto module = op->getParentOfType(); + + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + + // 2) convert the args to valid args to postgres_getresult abi + Value expr = op.getExpr(); + auto definingOp = expr.getDefiningOp(); + if (auto andOp = dyn_cast(definingOp)) { + Value left = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + andOp.getLeft()); + Value right = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + andOp.getRight()); + Value args[] = {left, rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + "AND "), right}; + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + rewriter.replaceOp(op, res); + } else if (auto orOp = dyn_cast(definingOp)) { + Value left = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + orOp.getLeft()); + Value right = rewriter.create( + op.getLoc(), MemRefType::get({-1}, 
rewriter.getI8Type()), + orOp.getRight()); + Value args[] = {left, rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + "OR "), right}; + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + rewriter.replaceOp(op, res); + } else if (auto calcBoolOp = dyn_cast(definingOp)){ + auto expr = calcBoolOp.getExpr(); + if (expr == "") { + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + ""); + rewriter.replaceOp(op, res); + return success(); + } + + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), expr); + Value args[] = {res, rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), " ")}; + res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + rewriter.replaceOp(op, res); + } else { + assert(0 && "unknown type to convert to string"); + } + + return success(); + } +}; struct ToStringOpLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(sql::SQLToStringOp op, + LogicalResult matchAndRewrite(sql::SQLToStringOp op, PatternRewriter &rewriter) const final { auto module = op->getParentOfType(); @@ -213,72 +267,128 @@ struct ToStringOpLowering : public OpRewritePattern { // 2) convert the args to valid args to postgres_getresult abi Value expr = op.getExpr(); auto definingOp = expr.getDefiningOp(); - if (auto selectOp = dyn_cast(definingOp)){ - Value current = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), "SELECT "); - bool prevColumn = false; - auto columns = selectOp.getColumns(); - for (mlir::Value v : columns) { - if (prevColumn) { - Value args[] = { current, rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), ", ") }; - current = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()),args); 
- } - Value col = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), v); - Value args[] = { col, rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), " ")}; - col = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), args); - Value args2[] = { current, col }; - current = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), args2); - prevColumn = true; - } - auto tableOp = selectOp.getTable().getDefiningOp(); - if (tableOp) { - Value args[] = { current, rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), "FROM ") }; - current = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()),args); - Value args2[] = { current, rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), selectOp.getTable()) }; - current = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()),args2); + if (auto selectOp = dyn_cast(definingOp)) { + Value current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), "SELECT "); + + bool prevColumn = false; + auto columns = selectOp.getColumns(); + for (mlir::Value v : columns) { + if (prevColumn) { + Value args[] = { + current, rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), ", ")}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); } - rewriter.replaceOp(op, current); - } else if (auto selectAllOp = dyn_cast(definingOp)){ - auto table = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), selectAllOp.getTable()); - Value res = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), "SELECT * FROM "); - Value args[] = { res, table }; - res = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()),args); - - rewriter.replaceOp(op, res); + Value col = rewriter.create( + op.getLoc(), MemRefType::get({-1}, 
rewriter.getI8Type()), v); + Value args[] = { + col, + rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), " ")}; + col = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + Value args2[] = {current, col}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args2); + prevColumn = true; + } + + auto tableOp = selectOp.getTable().getDefiningOp(); + if (tableOp) { + Value args[] = { + current, rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), "FROM ")}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + Value args2[] = {current, + rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + selectOp.getTable())}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args2); + } + + Value whereVal = rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + selectOp.getWhere()); + if (whereVal) { + auto whereOp = selectOp.getWhere().getDefiningOp(); + Value args[] = { + current, rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), "WHERE ")}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + Value args2[] = {current, whereVal}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args2); + } + + + auto limit = selectOp.getLimit(); + if (limit >= 0) { + Value args[] = {current, + rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + "LIMIT ")}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + Value args2[] = {current, + rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + std::to_string(limit))}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args2); + } + 
rewriter.replaceOp(op, current); } else if (auto tabOp = dyn_cast(definingOp)) { - Value res = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), tabOp.getExpr()); - rewriter.replaceOp(op, res); + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + tabOp.getExpr()); + Value args[] = { res, rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), " ")}; + res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + rewriter.replaceOp(op, res); + } else if (auto allColOp = dyn_cast(definingOp)) { + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), "*"); + rewriter.replaceOp(op, res); } else if (auto colOp = dyn_cast(definingOp)) { - Value res = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), colOp.getExpr()); - rewriter.replaceOp(op, res); - } else if (auto intOp = dyn_cast(definingOp)){ - Value res = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), intOp.getExpr()); - llvm::errs() << "intOp: " << intOp.getExpr() << "\n"; - rewriter.replaceOp(op, res); + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + colOp.getExpr()); + rewriter.replaceOp(op, res); + } else if (auto intOp = dyn_cast(definingOp)) { + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + intOp.getExpr()); + llvm::errs() << "intOp: " << intOp.getExpr() << "\n"; + rewriter.replaceOp(op, res); + } else if (auto whereOp = dyn_cast(definingOp)) { + Value res = rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + whereOp.getExpr()); + rewriter.replaceOp(op, res); } else { - assert(0 && "unknown type to convert to string"); + assert(0 && "unknown type to convert to string"); } - + return success(); } }; + + struct ExecuteOpLowering : public OpRewritePattern { using 
OpRewritePattern::OpRewritePattern; @@ -296,39 +406,37 @@ struct ExecuteOpLowering : public OpRewritePattern { // 2) convert the args to valid args to postgres_getresult abi Value conn = op.getConn(); conn = rewriter.create(op.getLoc(), - rewriter.getI8Type(), conn); + rewriter.getI8Type(), conn); conn = rewriter.create( op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), conn); - conn = rewriter.create(op.getLoc(), - executefn.getFunctionType().getInput(0), conn); - - Value command = op.getCommand(); - // auto name = "str" + std::to_string((long long int)(Operation *)command.getDefiningOp()); - command = rewriter.create(op.getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), command); + conn = rewriter.create( + op.getLoc(), executefn.getFunctionType().getInput(0), conn); + + Value command = op.getCommand(); + // auto name = "str" + std::to_string((long long int)(Operation + // *)command.getDefiningOp()); + command = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), command); llvm::errs() << "command: " << command << "\n"; llvm::errs() << "command type: " << command.getType() << "\n"; // auto type = MemRefType::get({-1}, rewriter.getI8Type()); - // auto globalOp = rewriter.create( - // op.getLoc(), type, /* isConstant */ true, LLVM::Linkage::Internal, name, - // mlir::Attribute(), + // op.getLoc(), type, /* isConstant */ true, LLVM::Linkage::Internal, + // name, mlir::Attribute(), // /* alignment */ 0, /* addrSpace */ 0); - // 3) call and replace - Value args[] = {conn, command}; - + Value args[] = {conn, command}; + Value res = rewriter.create(op.getLoc(), executefn, args) ->getResult(0); - res = rewriter.create(op.getLoc(), - LLVM::LLVMPointerType::get(rewriter.getI8Type()), res); - res = rewriter.create( - op.getLoc(), rewriter.getI64Type(), res); - res = rewriter.create(op.getLoc(), - op.getType(), res); + res = rewriter.create( + op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), res); + res = 
rewriter.create(op.getLoc(), rewriter.getI64Type(), + res); + res = rewriter.create(op.getLoc(), op.getType(), res); rewriter.replaceOp(op, res); @@ -349,28 +457,26 @@ void SQLLower::runOnOperation() { mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type())}; mlir::Type rettypes[] = {builder.getI64Type()}; - auto fn = - builder.create(module.getLoc(), "PQntuples", - builder.getFunctionType(argtypes, rettypes)); + auto fn = builder.create( + module.getLoc(), "PQntuples", + builder.getFunctionType(argtypes, rettypes)); SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); } if (!dyn_cast_or_null(symbolTable.lookupSymbolIn( module, builder.getStringAttr("PQgetvalue")))) { - mlir::Type argtypes[] = { - MemRefType::get({-1}, builder.getI8Type()), - builder.getI64Type(), - builder.getI64Type()}; + mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type()), + builder.getI64Type(), builder.getI64Type()}; mlir::Type rettypes[] = {MemRefType::get({-1}, builder.getI8Type())}; - auto fn = - builder.create(module.getLoc(), "PQgetvalue", - builder.getFunctionType(argtypes, rettypes)); + auto fn = builder.create( + module.getLoc(), "PQgetvalue", + builder.getFunctionType(argtypes, rettypes)); SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); } - if (!dyn_cast_or_null( - symbolTable.lookupSymbolIn(module, builder.getStringAttr("PQexec")))) { + if (!dyn_cast_or_null(symbolTable.lookupSymbolIn( + module, builder.getStringAttr("PQexec")))) { mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type()), MemRefType::get({-1}, builder.getI8Type())}; mlir::Type rettypes[] = {MemRefType::get({-1}, builder.getI8Type())}; @@ -383,7 +489,8 @@ void SQLLower::runOnOperation() { if (!dyn_cast_or_null( symbolTable.lookupSymbolIn(module, builder.getStringAttr("atoi")))) { mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type())}; - // mlir::Type argtypes[] = {LLVM::LLVMPointerType::get(builder.getI64Type())}; + // 
mlir::Type argtypes[] = + // {LLVM::LLVMPointerType::get(builder.getI64Type())}; // todo use data layout mlir::Type rettypes[] = {builder.getI64Type()}; @@ -398,13 +505,13 @@ void SQLLower::runOnOperation() { patterns.insert(&getContext()); patterns.insert(&getContext()); patterns.insert(&getContext()); + patterns.insert(&getContext()); patterns.insert(&getContext()); - for (auto *dialect : getContext().getLoadedDialects()) - dialect->getCanonicalizationPatterns(patterns); + dialect->getCanonicalizationPatterns(patterns); for (RegisteredOperationName op : getContext().getRegisteredOperations()) - op.getCanonicalizationPatterns(patterns, &getContext()); + op.getCanonicalizationPatterns(patterns, &getContext()); GreedyRewriteConfig config; (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), @@ -416,5 +523,5 @@ namespace sql { std::unique_ptr createSQLLowerPass() { return std::make_unique(); } -} // namespace polygeist +} // namespace sql } // namespace mlir From 0b857244130a6026b412eb2cf45ace7681d220ad Mon Sep 17 00:00:00 2001 From: Carl Guo Date: Tue, 14 Nov 2023 16:21:29 -0500 Subject: [PATCH 15/15] now with sql --- include/CMakeLists.txt | 4 +- include/sql/Parser.h | 2 - include/sql/SQLOps.td | 23 ++++++++++ lib/CMakeLists.txt | 4 +- lib/sql/CMakeLists.txt | 3 +- test_with_pragma.c | 71 ++++++++++++++++-------------- tools/cgeist/CMakeLists.txt | 8 ++-- tools/polygeist-opt/CMakeLists.txt | 4 ++ 8 files changed, 78 insertions(+), 41 deletions(-) diff --git a/include/CMakeLists.txt b/include/CMakeLists.txt index f5e589845278..7719ec0de2aa 100644 --- a/include/CMakeLists.txt +++ b/include/CMakeLists.txt @@ -1,2 +1,4 @@ add_subdirectory(polygeist) -add_subdirectory(sql) \ No newline at end of file +if (ENABLE_SQL) + add_subdirectory(sql) +endif() \ No newline at end of file diff --git a/include/sql/Parser.h b/include/sql/Parser.h index 99ba7579173e..d7ec2f5a1e37 100644 --- a/include/sql/Parser.h +++ b/include/sql/Parser.h @@ -7,8 +7,6 @@ 
//===----------------------------------------------------------------------===// - - #ifndef SQLPARSER_H #define SQLPARSER_H diff --git a/include/sql/SQLOps.td b/include/sql/SQLOps.td index eafeaac1c5d5..8191f71a6673 100644 --- a/include/sql/SQLOps.td +++ b/include/sql/SQLOps.td @@ -49,6 +49,16 @@ def WhereOp: SQL_Op<"where", [Pure]> { let hasCanonicalizer = 0; } +def BoolArithOp: SQL_Op<"bool_arith", [Pure]> { + let summary = "bool_arith op"; + + let arguments = (ins SQLBoolType:$left, SQLBoolType:$right, StrAttr:$op); + let results = (outs SQLBoolType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + def CalcBoolOp: SQL_Op<"calc_bool", [Pure]> { let summary = "calc_bool op"; @@ -59,6 +69,19 @@ def CalcBoolOp: SQL_Op<"calc_bool", [Pure]> { let hasCanonicalizer = 0; } + + + +def ArithOp: SQL_Op<"arith", [Pure]> { + let summary = "arith op"; + + let arguments = (ins SQLExprType:$left, SQLExprType:$right, StrAttr:$op); + let results = (outs SQLExprType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + def AndOp: SQL_Op<"and", [Pure]> { let summary = "and op"; diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index f5e589845278..7719ec0de2aa 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1,2 +1,4 @@ add_subdirectory(polygeist) -add_subdirectory(sql) \ No newline at end of file +if (ENABLE_SQL) + add_subdirectory(sql) +endif() \ No newline at end of file diff --git a/lib/sql/CMakeLists.txt b/lib/sql/CMakeLists.txt index fa5bae87f548..64a2f92db852 100644 --- a/lib/sql/CMakeLists.txt +++ b/lib/sql/CMakeLists.txt @@ -3,6 +3,7 @@ Types.cpp Dialect.cpp Ops.cpp Parser.cpp +NewParser.cpp ADDITIONAL_HEADER_DIRS @@ -13,7 +14,7 @@ MLIRSQLOpsIncGen # MLIRSQLTypesIncGen LINK_LIBS PUBLIC -MLIRIR +MLIRIR sqlparse_lib ) add_subdirectory(Passes) \ No newline at end of file diff --git a/test_with_pragma.c b/test_with_pragma.c index d12558a64839..7c144f328c5b 100644 --- a/test_with_pragma.c +++ b/test_with_pragma.c @@ -1,52 +1,57 @@ 
+#include #include #include -#include // PGresult *PQexec(PGconn*, const char* command); // PQgetvalue -// %7 = call @PQexec(%2, %6) : (memref, memref) -> memref -#pragma lower_to(num_rows_fn, "sql.num_results") -int num_rows_fn(size_t);// char* +// %7 = call @PQexec(%2, %6) : (memref, memref) -> +// memref +// #pragma lower_to(num_rows_fn, "sql.num_results") +// int num_rows_fn(size_t);// char* -#pragma lower_to(get_value_fn_int, "sql.get_value") -int get_value_fn_int(size_t, int, int); +// #pragma lower_to(get_value_fn_int, "sql.get_value") +// int get_value_fn_int(size_t, int, int); -#pragma lower_to(get_value_fn_double, "sql.get_value") -double get_value_fn_double(size_t, int, int); +// #pragma lower_to(get_value_fn_double, "sql.get_value") +// double get_value_fn_double(size_t, int, int); +// #pragma lower_to(execute, "sql.execute") +// PGresult* execute(size_t, char*); void do_exit(PGconn *conn) { - PQfinish(conn); - exit(1); + PQfinish(conn); + exit(1); } int main() { - - PGconn *conn = PQconnectdb("user=janbodnar dbname=testdb"); - - if (PQstatus(conn) == CONNECTION_BAD) { - - fprintf(stderr, "Connection to database failed: %s\n", - PQerrorMessage(conn)); - do_exit(conn); - } - PGresult *res = PQexec(conn, "SELECT VERSION()"); - - if (PQresultStatus(res) != PGRES_TUPLES_OK) { + PGconn *conn = PQconnectdb("user=carl dbname=testdb"); - printf("No data retrieved\n"); - PQclear(res); - do_exit(conn); - } - - printf("%s\n", PQgetvalue(res, 0, 0)); - printf("%d\n", get_value_fn_int((size_t)res, 0, 0)); - printf("%d\n", num_rows_fn((size_t)res)); - // res, 0, 0)); + if (PQstatus(conn) == CONNECTION_BAD) { + fprintf(stderr, "Connection to database failed: %s\n", + PQerrorMessage(conn)); + do_exit(conn); + } + + // PGresult *res = PQexec(conn, "SELECT 17"); + PGresult *res = PQexec(conn, "SELECT a FROM table1"); + PGresult *res1 = PQexec(conn, "SELECT * FROM table1 WHERE b > 10 OR c < 10 AND a <= 20"); + PGresult *res2 = PQexec(conn, "SELECT * FROM table1 WHERE b > 
10 AND c < 10"); + PGresult *res3 = PQexec(conn, "SELECT b, c FROM table1 WHERE a <= 10 LIMIT 10"); + // PGresult *res3 = PQexec(conn, "SELECT b, c FROM table1 LIMIT ALL"); + if (PQresultStatus(res) != PGRES_TUPLES_OK) { + + printf("No data retrieved\n"); PQclear(res); - PQfinish(conn); + do_exit(conn); + } + + PQclear(res); + PQclear(res1); + PQclear(res2); + PQclear(res3); + PQfinish(conn); - return 0; + return 0; } diff --git a/tools/cgeist/CMakeLists.txt b/tools/cgeist/CMakeLists.txt index 30dd13d116b7..c0663093a7c5 100644 --- a/tools/cgeist/CMakeLists.txt +++ b/tools/cgeist/CMakeLists.txt @@ -59,8 +59,6 @@ target_compile_definitions(cgeist PUBLIC -DLLVM_OBJ_ROOT="${LLVM_BINARY_DIR}") target_link_libraries(cgeist PRIVATE MLIRSCFTransforms MLIRPolygeist - MLIRSQL - MLIRSupport MLIRIR MLIRAnalysis @@ -77,7 +75,6 @@ target_link_libraries(cgeist PRIVATE MLIRMathToLLVM MLIRTargetLLVMIRImport MLIRPolygeistTransforms - MLIRSQLTransforms MLIRLLVMToLLVMIRTranslation MLIRSCFToOpenMP MLIROpenMPToLLVM @@ -110,6 +107,11 @@ target_link_libraries(cgeist PRIVATE clangLex clangSerialization ) + +if (ENABLE_SQL) + target_link_libraries(cgeist PRIVATE MLIRSQLTransforms MLIRSQL) +endif() + add_dependencies(cgeist MLIRPolygeistOpsIncGen MLIRPolygeistPassIncGen) add_dependencies(cgeist MLIRSQLOpsIncGen MLIRSQLPassIncGen) add_subdirectory(Test) diff --git a/tools/polygeist-opt/CMakeLists.txt b/tools/polygeist-opt/CMakeLists.txt index c2fd7c41ec4c..3221aa842d0d 100644 --- a/tools/polygeist-opt/CMakeLists.txt +++ b/tools/polygeist-opt/CMakeLists.txt @@ -7,6 +7,10 @@ set(LIBS MLIRPolygeist MLIRPolygeistTransforms ) +if (ENABLE_SQL) + set(LIBS ${LIBS} MLIRSQLTransforms MLIRSQL) +endif() + add_llvm_executable(polygeist-opt polygeist-opt.cpp) install(TARGETS polygeist-opt