diff --git a/CMakeLists.txt b/CMakeLists.txt index 59e116c6414e..3635652592b6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,8 +2,7 @@ cmake_minimum_required(VERSION 3.10) include(CheckCXXSourceCompiles) -set(POLYGEIST_ENABLE_CUDA 0 CACHE BOOL "Enable CUDA frontend and backend") -set(POLYGEIST_ENABLE_ROCM 0 CACHE BOOL "Enable ROCM backend") +set(POLYGEIST_ENABLE_CUDA 0 CACHE BOOL "Enable CUDA compilation support") if(POLICY CMP0068) cmake_policy(SET CMP0068 NEW) @@ -21,16 +20,11 @@ endif() option(LLVM_INCLUDE_TOOLS "Generate build targets for the LLVM tools." ON) option(LLVM_BUILD_TOOLS "Build the LLVM tools. If OFF, just generate build targets." ON) +option(ENABLE_SQL "Build SQL dialect" OFF) + set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin) set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib) -find_program(XXD_BIN xxd) - -# TODO should depend on OS -set(POLYGEIST_PGO_DEFAULT_DATA_DIR "/var/tmp/polygeist/pgo/" CACHE STRING "Directory for PGO data") -set(POLYGEIST_PGO_ALTERNATIVE_ENV_VAR "POLYGEIST_PGO_ALTERNATIVE" CACHE STRING "Env var name to specify alternative to profile") -set(POLYGEIST_PGO_DATA_DIR_ENV_VAR "POLYGEIST_PGO_DATA_DIR" CACHE STRING "Env var name to specify PGO data dir") - if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) project(polygeist LANGUAGES CXX C) @@ -112,6 +106,40 @@ set(LLVM_LIT_ARGS "-sv" CACHE STRING "lit default options") list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake/modules") include(sanitizers) +if (ENABLE_SQL) + include(FetchContent) + include(ExternalProject) + + FetchContent_Declare(sqlparser_ext + GIT_REPOSITORY https://github.com/wsmoses/sql-parser + GIT_TAG c2471248cef8cd33081e698e8ac65d691283dbd4 + ) + + FetchContent_GetProperties(sqlparser_ext) + + FetchContent_MakeAvailable(sqlparser_ext) + + ExternalProject_Add(sqlparser + SOURCE_DIR ${sqlparser_ext_SOURCE_DIR} + INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}/sql/install + CONFIGURE_COMMAND "" + BUILD_COMMAND ${CMAKE_COMMAND} -E env + CXX=${CMAKE_CXX_COMPILER} + make static=yes -C ${sqlparser_ext_SOURCE_DIR} + BUILD_IN_SOURCE TRUE + INSTALL_COMMAND "" + BUILD_BYPRODUCTS ${sqlparser_ext_SOURCE_DIR}/libsqlparser.a + ) + + + add_library(sqlparse_lib INTERFACE) + + target_include_directories(sqlparse_lib INTERFACE "${sqlparser_ext_SOURCE_DIR}/src") + target_link_libraries(sqlparse_lib INTERFACE ${sqlparser_ext_SOURCE_DIR}/libsqlparser.a) + add_dependencies(sqlparse_lib sqlparser) + +endif() + add_subdirectory(include) add_subdirectory(lib) add_subdirectory(tools) diff --git a/include/CMakeLists.txt b/include/CMakeLists.txt index da66b9bf293f..7719ec0de2aa 100644 --- a/include/CMakeLists.txt +++ b/include/CMakeLists.txt @@ -1 +1,4 @@ add_subdirectory(polygeist) +if (ENABLE_SQL) + add_subdirectory(sql) +endif() \ No newline at end of file diff --git a/include/sql/CMakeLists.txt b/include/sql/CMakeLists.txt new file mode 100644 index 000000000000..1ee2b045b2cc --- /dev/null +++ b/include/sql/CMakeLists.txt @@ -0,0 +1,4 @@ +add_mlir_dialect(SQLOps sql) +# add_mlir_doc(SQLDialect -gen-dialect-doc SQLDialect SQL/) +# add_mlir_doc(SQLOps -gen-op-doc SQLOps SQL/) +add_subdirectory(Passes) \ No newline at end of file diff --git a/include/sql/Parser.h b/include/sql/Parser.h new file mode 100644 index 000000000000..d7ec2f5a1e37 --- /dev/null +++ b/include/sql/Parser.h @@ -0,0 +1,28 @@ +//===- Parser.h - SQL dialect -----------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + + +#ifndef SQLPARSER_H +#define SQLPARSER_H + +#include "mlir/IR/Dialect.h" + +#include "mlir/IR/Value.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Attributes.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/BuiltinTypes.h" + +#include "sql/SQLDialect.h" +#include "sql/SQLOps.h" +#include "sql/SQLTypes.h" + +mlir::Value parseSQL(mlir::Location loc, mlir::OpBuilder& builder, std::string str); + +#endif // SQLPARSER_H \ No newline at end of file diff --git a/include/sql/Passes/CMakeLists.txt b/include/sql/Passes/CMakeLists.txt new file mode 100644 index 000000000000..298ddabf48f7 --- /dev/null +++ b/include/sql/Passes/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name sql) +add_public_tablegen_target(MLIRSQLPassIncGen) + +add_mlir_doc(Passes SQLPasses ./ -gen-pass-doc) diff --git a/include/sql/Passes/Passes.h b/include/sql/Passes/Passes.h new file mode 100644 index 000000000000..82bc1667b8dd --- /dev/null +++ b/include/sql/Passes/Passes.h @@ -0,0 +1,51 @@ +#ifndef SQL_DIALECT_SQL_PASSES_H +#define SQL_DIALECT_SQL_PASSES_H + +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Pass/Pass.h" +#include +namespace mlir { +class PatternRewriter; +class RewritePatternSet; +class DominanceInfo; +namespace sql { + +std::unique_ptr createSQLLowerPass(); +std::unique_ptr createSQLRaisingPass(); +} // namespace sql +} // namespace mlir + + + +namespace mlir { +// Forward declaration from Dialect.h +template +void registerDialect(DialectRegistry ®istry); + +namespace arith { +class ArithDialect; +} // end namespace arith + +namespace scf { +class SCFDialect; +} // end namespace scf + +namespace memref { +class MemRefDialect; +} // end namespace memref + +namespace func { +class FuncDialect; +} + +class AffineDialect; +namespace LLVM { +class LLVMDialect; +} + +#define GEN_PASS_REGISTRATION +#include "sql/Passes/Passes.h.inc" + +} // end namespace mlir + +#endif // SQL_DIALECT_SQL_PASSES_H diff --git a/include/sql/Passes/Passes.td b/include/sql/Passes/Passes.td new file mode 100644 index 000000000000..11968e1415fe --- /dev/null +++ b/include/sql/Passes/Passes.td @@ -0,0 +1,22 @@ +#ifndef SQL_PASSES +#define SQL_PASSES + +include "mlir/Pass/PassBase.td" + + +def SQLLower : Pass<"sql-lower", "mlir::ModuleOp"> { + let summary = "Lower sql op to mlir"; + let dependentDialects = + ["arith::ArithDialect", "func::FuncDialect", "LLVM::LLVMDialect"]; + let constructor = "mlir::sql::createSQLLowerPass()"; +} + + +def SQLRaising : Pass<"sql-raising", "mlir::ModuleOp"> { + let summary = "Raise sql op to mlir"; + let dependentDialects = + ["arith::ArithDialect", "func::FuncDialect", "LLVM::LLVMDialect"]; + let constructor = "mlir::sql::createSQLRaisingPass()"; +} + +#endif // SQL_PASSES diff --git a/include/sql/Passes/Utils.h b/include/sql/Passes/Utils.h new file mode 100644 index 000000000000..b191d89ba6a6 --- /dev/null +++ b/include/sql/Passes/Utils.h @@ -0,0 +1,139 @@ +#pragma once + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/IntegerSet.h" + +static inline mlir::scf::IfOp +cloneWithResults(mlir::scf::IfOp op, mlir::OpBuilder &rewriter, + mlir::BlockAndValueMapping mapping = {}) { + using namespace mlir; + return rewriter.create(op.getLoc(), op.getResultTypes(), + mapping.lookupOrDefault(op.getCondition()), + true); +} +static inline mlir::AffineIfOp +cloneWithResults(mlir::AffineIfOp op, mlir::OpBuilder &rewriter, + mlir::BlockAndValueMapping mapping = {}) { + using namespace mlir; + SmallVector lower; + for (auto o : op.getOperands()) + lower.push_back(mapping.lookupOrDefault(o)); + return rewriter.create(op.getLoc(), op.getResultTypes(), + op.getIntegerSet(), lower, true); +} + +static inline mlir::scf::IfOp +cloneWithoutResults(mlir::scf::IfOp op, mlir::OpBuilder &rewriter, + mlir::BlockAndValueMapping mapping = {}, + mlir::TypeRange types = {}) { + using namespace mlir; + return rewriter.create( + op.getLoc(), types, mapping.lookupOrDefault(op.getCondition()), true); +} +static inline mlir::AffineIfOp +cloneWithoutResults(mlir::AffineIfOp op, mlir::OpBuilder &rewriter, + mlir::BlockAndValueMapping mapping = {}, + mlir::TypeRange types = {}) { + using namespace mlir; + SmallVector lower; + for (auto o : op.getOperands()) + lower.push_back(mapping.lookupOrDefault(o)); + return rewriter.create(op.getLoc(), types, op.getIntegerSet(), + lower, true); +} + +static inline mlir::scf::ForOp +cloneWithoutResults(mlir::scf::ForOp op, mlir::PatternRewriter &rewriter, + mlir::BlockAndValueMapping mapping = {}) { + using namespace mlir; + return rewriter.create( + op.getLoc(), mapping.lookupOrDefault(op.getLowerBound()), + mapping.lookupOrDefault(op.getUpperBound()), + mapping.lookupOrDefault(op.getStep())); +} +static inline mlir::AffineForOp +cloneWithoutResults(mlir::AffineForOp op, mlir::PatternRewriter &rewriter, + mlir::BlockAndValueMapping mapping = {}) { + using namespace mlir; + SmallVector lower; + for (auto o : op.getLowerBoundOperands()) + lower.push_back(mapping.lookupOrDefault(o)); + SmallVector upper; + for (auto o : op.getUpperBoundOperands()) + upper.push_back(mapping.lookupOrDefault(o)); + return rewriter.create(op.getLoc(), lower, op.getLowerBoundMap(), + upper, op.getUpperBoundMap(), + op.getStep()); +} + +static inline void clearBlock(mlir::Block *block, + mlir::PatternRewriter &rewriter) { + for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) { + assert(op.use_empty() && "expected 'op' to have no uses"); + rewriter.eraseOp(&op); + } +} + +static inline mlir::Block *getThenBlock(mlir::scf::IfOp op) { + return op.thenBlock(); +} +static inline mlir::Block *getThenBlock(mlir::AffineIfOp op) { + return op.getThenBlock(); +} +static inline mlir::Block *getElseBlock(mlir::scf::IfOp op) { + return op.elseBlock(); +} +static inline mlir::Block *getElseBlock(mlir::AffineIfOp op) { + if (op.hasElse()) + return op.getElseBlock(); + else + return nullptr; +} + +static inline mlir::Region &getThenRegion(mlir::scf::IfOp op) { + return op.getThenRegion(); +} +static inline mlir::Region &getThenRegion(mlir::AffineIfOp op) { + return op.getThenRegion(); +} +static inline mlir::Region &getElseRegion(mlir::scf::IfOp op) { + return op.getElseRegion(); +} +static inline mlir::Region &getElseRegion(mlir::AffineIfOp op) { + return op.getElseRegion(); +} + +static inline mlir::scf::YieldOp getThenYield(mlir::scf::IfOp op) { + return op.thenYield(); +} +static inline mlir::AffineYieldOp getThenYield(mlir::AffineIfOp op) { + return llvm::cast(op.getThenBlock()->getTerminator()); +} +static inline mlir::scf::YieldOp getElseYield(mlir::scf::IfOp op) { + return op.elseYield(); +} +static inline mlir::AffineYieldOp getElseYield(mlir::AffineIfOp op) { + return llvm::cast(op.getElseBlock()->getTerminator()); +} + +static inline bool inBound(mlir::scf::IfOp op, mlir::Value v) { + return op.getCondition() == v; +} +static inline bool inBound(mlir::AffineIfOp op, mlir::Value v) { + return llvm::any_of(op.getOperands(), [&](mlir::Value e) { return e == v; }); +} +static inline bool inBound(mlir::scf::ForOp op, mlir::Value v) { + return op.getUpperBound() == v; +} +static inline bool inBound(mlir::AffineForOp op, mlir::Value v) { + return llvm::any_of(op.getUpperBoundOperands(), + [&](mlir::Value e) { return e == v; }); +} +static inline bool hasElse(mlir::scf::IfOp op) { + return op.getElseRegion().getBlocks().size() > 0; +} +static inline bool hasElse(mlir::AffineIfOp op) { + return op.getElseRegion().getBlocks().size() > 0; +} diff --git a/include/sql/SQLDialect.h b/include/sql/SQLDialect.h new file mode 100644 index 000000000000..bdb0178a0662 --- /dev/null +++ b/include/sql/SQLDialect.h @@ -0,0 +1,16 @@ +//===- SQLDialect.h - SQL dialect -----------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef SQL_DIALECT_H +#define SQL_DIALECT_H + +#include "mlir/IR/Dialect.h" + +#include "sql/SQLOpsDialect.h.inc" + +#endif \ No newline at end of file diff --git a/include/sql/SQLDialect.td b/include/sql/SQLDialect.td new file mode 100644 index 000000000000..8026cd4268cf --- /dev/null +++ b/include/sql/SQLDialect.td @@ -0,0 +1,41 @@ +//===- SQLDialect.td - SQL dialect -----------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef SQL_DIALECT +#define SQL_DIALECT + + + +include "mlir/IR/FunctionInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/CastInterfaces.td" + + +def SQL_Dialect : Dialect { + let summary = "A dialect for SQL languages in MLIR."; + let description = [{ + TBD + }]; + let name = "sql"; + let cppNamespace = "::mlir::sql"; + + let useDefaultTypePrinterParser = 1; + let extraClassDeclaration = [{ + void registerTypes(); + }]; +} + + +#endif // SQL_DIALECT + + + + + diff --git a/include/sql/SQLOps.h b/include/sql/SQLOps.h new file mode 100644 index 000000000000..1b5c6f4783ec --- /dev/null +++ b/include/sql/SQLOps.h @@ -0,0 +1,23 @@ +//===- SQLOps.h - SQL dialect ops --------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef SQLOPS_H +#define SQLOPS_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ViewLikeInterface.h" +#include "llvm/Support/CommandLine.h" +#include "sql/SQLTypes.h" +#define GET_OP_CLASSES +#include "sql/SQLOps.h.inc" + +#endif \ No newline at end of file diff --git a/include/sql/SQLOps.td b/include/sql/SQLOps.td new file mode 100644 index 000000000000..8191f71a6673 --- /dev/null +++ b/include/sql/SQLOps.td @@ -0,0 +1,243 @@ +//===- SQLOps.td - SQL dialect ops ----------------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef SQL_OPS +#define SQL_OPS + +include "mlir/IR/OpBase.td" +include "SQLDialect.td" +include "SQLTypes.td" + + + +class SQL_Op traits = []> + : Op; + + +def IntOp : SQL_Op<"int", [Pure]> { + let summary = "int op"; + + let arguments = (ins StrAttr:$expr); + let results = (outs SQLExprType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +def ColumnOp : SQL_Op<"column", [Pure]> { + let summary = "column op"; + + let arguments = (ins StrAttr:$expr); + let results = (outs SQLExprType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +def WhereOp: SQL_Op<"where", [Pure]> { + let summary = "where op"; + + let arguments = (ins SQLBoolType:$expr); + let results = (outs SQLExprType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +def BoolArithOp: SQL_Op<"bool_arith", [Pure]> { + let summary = "bool_arith op"; + + let arguments = (ins SQLBoolType:$left, SQLBoolType:$right, StrAttr:$op); + let results = (outs SQLBoolType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +def CalcBoolOp: SQL_Op<"calc_bool", [Pure]> { + let summary = "calc_bool op"; + + let arguments = (ins StrAttr:$expr); + let results = (outs SQLBoolType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + + + + +def ArithOp: SQL_Op<"arith", [Pure]> { + let summary = "arith op"; + + let arguments = (ins SQLExprType:$left, SQLExprType:$right, StrAttr:$op); + let results = (outs SQLExprType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +def AndOp: SQL_Op<"and", [Pure]> { + let summary = "and op"; + + let arguments = (ins SQLBoolType:$left, SQLBoolType:$right); + let results = (outs SQLBoolType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +def OrOp: SQL_Op<"or", [Pure]> { + let summary = "or op"; + + let arguments = (ins SQLBoolType:$left, SQLBoolType:$right); + let results = (outs SQLBoolType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +def TableOp : SQL_Op<"table", [Pure]> { + let summary = "table"; + + let arguments = (ins StrAttr:$expr); + let results = (outs SQLExprType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + + +def SQLConstantStringOp : SQL_Op<"str_constant", [Pure]> { + let summary = "str_constant"; + + let arguments = (ins StrAttr:$input); + let results = (outs AnyType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +def SQLToStringOp : SQL_Op<"to_string", [Pure]> { + let summary = "to_string"; + + let arguments = (ins SQLExprType:$expr); + let results = (outs AnyType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + + +def SQLBoolToStringOp : SQL_Op<"bool_to_string", [Pure]> { + let summary = "bool_to_string"; + + let arguments = (ins SQLBoolType:$expr); + let results = (outs AnyType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + + +def SQLStringConcatOp : SQL_Op<"string_concat", [Pure]> { + let summary = "string_concat"; + + let arguments = (ins Variadic:$expr); + let results = (outs AnyType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 1; +} + +def ConstantBoolOp : SQL_Op <"constant_bool", [Pure]> { + let summary = "constant_bool"; + let results = (outs SQLBoolType:$result); +} + + +def SelectOp : SQL_Op<"select", [Pure]> { + let summary = "select"; + // i need to specify the size of a Variadic? + let arguments = (ins Variadic:$columns, + SQLExprType:$table, + SQLExprType:$where, + SI64Attr:$limit); + // attribute limit if >= 0 then its the real thing, otherwise its infinity + let results = (outs SQLExprType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + +def AllColumnsOp : SQL_Op<"all_columns", [Pure]> { + let summary = "all_columns"; + let results = (outs SQLExprType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + + + + +def UnparsedOp : SQL_Op<"unparsed", [Pure]> { + let summary = "unparsed sql op"; + + let arguments = (ins AnyType:$input); + let results = (outs SQLExprType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 1; +} + +def ExecuteOp : SQL_Op<"execute", []> { + let summary = "execute query"; + + let arguments = (ins Index:$conn, SQLExprType:$command); + let results = (outs Index:$result); + + let hasFolder = 0; + let hasCanonicalizer = 0; +} + + +def NumResultsOp : SQL_Op<"num_results", [Pure]> { + let summary = "number of results"; + + let arguments = (ins Index:$handle); + let results = (outs Index:$result); + + let hasFolder = 0; + let hasCanonicalizer = 1; +} + +def GetValueOp : SQL_Op<"get_value", [Pure]> { + let summary = "get value of execution"; + + let arguments = (ins Index:$handle, Index:$column, Index:$row); + let results = (outs AnyType:$result); + + let hasFolder = 0; + let hasCanonicalizer = 1; +} + +// def SelectOp : SQL_Op<"select", [Pure]>{ +// let summary = "select"; + +// let arguments = (ins StrArrayAttr:$columns, +// Optional:$from); +// // optional>:$where, +// // optional:$limit, +// // optional:$order); +// let results = (outs AnyType:$result); + +// let hasFolder = 0; +// let hasCanonicalizer = 0; +// } +#endif // SQL_OPS \ No newline at end of file diff --git a/include/sql/SQLTypes.h b/include/sql/SQLTypes.h new file mode 100644 index 000000000000..28f9d91e393a --- /dev/null +++ b/include/sql/SQLTypes.h @@ -0,0 +1,18 @@ +//===- SQLTypes.h - SQL dialect types --------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef SQL_SQLTYPES_H +#define SQL_SQLTYPES_H + +#include "mlir/IR/BuiltinTypes.h" + +#define GET_TYPEDEF_CLASSES +#include "sql/SQLOpsTypes.h.inc" + + +#endif // SQL_SQLTYPES_H \ No newline at end of file diff --git a/include/sql/SQLTypes.td b/include/sql/SQLTypes.td new file mode 100644 index 000000000000..a8fc3f0e30d6 --- /dev/null +++ b/include/sql/SQLTypes.td @@ -0,0 +1,40 @@ +//===- SQLTypes.td - SQL dialect types ----------------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef SQL_TYPES +#define SQL_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "SQLDialect.td" + +class SQL_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + + + +def SQLExprType : SQL_Type<"Expr", "expr"> { + let summary = "SQL expression type"; + let description = "Custom attr or value type in sql dialect"; + + // placeholder params + // let parameters = (ins StringRefParameter<"the custom value">:$value); + // let assemblyFormat = "`<` $value `>`"; +} + +def SQLBoolType : SQL_Type<"Bool", "bool"> { + let summary = "SQL boolean type"; + let description = "Custom attr or value type in sql dialect"; + + // placeholder params + // let parameters = (ins StringRefParameter<"the custom value">:$value); + // let assemblyFormat = "`<` $value `>`"; +} + +#endif // SQL_TYPES \ No newline at end of file diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index da66b9bf293f..7719ec0de2aa 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1 +1,4 @@ add_subdirectory(polygeist) +if (ENABLE_SQL) + add_subdirectory(sql) +endif() \ No newline at end of file diff --git a/lib/sql/CMakeLists.txt b/lib/sql/CMakeLists.txt new file mode 100644 index 000000000000..64a2f92db852 --- /dev/null +++ b/lib/sql/CMakeLists.txt @@ -0,0 +1,20 @@ +add_mlir_dialect_library(MLIRSQL +Types.cpp +Dialect.cpp +Ops.cpp +Parser.cpp +NewParser.cpp + + +ADDITIONAL_HEADER_DIRS +${PROJECT_SOURCE_DIR}/include/sql + +DEPENDS +MLIRSQLOpsIncGen +# MLIRSQLTypesIncGen + +LINK_LIBS PUBLIC +MLIRIR sqlparse_lib +) + +add_subdirectory(Passes) \ No newline at end of file diff --git a/lib/sql/Dialect.cpp b/lib/sql/Dialect.cpp new file mode 100644 index 000000000000..104e510bf59d --- /dev/null +++ b/lib/sql/Dialect.cpp @@ -0,0 +1,30 @@ +//===- SQLDialect.cpp - SQL dialect ---------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/DialectImplementation.h" +#include "sql/SQLDialect.h" +#include "sql/SQLOps.h" +#include "sql/SQLTypes.h" + +using namespace mlir; +using namespace mlir::sql; + +//===----------------------------------------------------------------------===// +// SQL dialect. +//===----------------------------------------------------------------------===// + +void SQLDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "sql/SQLOps.cpp.inc" + >(); +registerTypes(); + +} + +#include "sql/SQLOpsDialect.cpp.inc" diff --git a/lib/sql/Ops.cpp b/lib/sql/Ops.cpp new file mode 100644 index 000000000000..4f2c2b34cddb --- /dev/null +++ b/lib/sql/Ops.cpp @@ -0,0 +1,293 @@ +//===- SQLOps.cpp - SQL dialect ops ---------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include +#include +#include +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "polygeist/Ops.h" +#include "sql/Parser.h" +#include "sql/SQLDialect.h" +#include "sql/SQLOps.h" +#include "sql/SQLTypes.h" + +#define GET_OP_CLASSES +#include "sql/SQLOps.cpp.inc" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/Transforms/SideEffectUtils.h" + +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/Debug.h" + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/SmallVector.h" + +#define DEBUG_TYPE "sql" + +using namespace mlir; +using namespace sql; +using namespace mlir::arith; + +class GetValueOpTypeFix final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GetValueOp op, + PatternRewriter &rewriter) const override { + + bool changed = false; + + Value handle = op.getOperand(0); + if (!handle.getType().isa()) { + handle = rewriter.create(op.getLoc(), + rewriter.getIndexType(), handle); + changed = true; + } + Value row = op.getOperand(1); + if (!row.getType().isa()) { + row = rewriter.create(op.getLoc(), rewriter.getIndexType(), + row); + changed = true; + } + Value column = op.getOperand(2); + if (!column.getType().isa()) { + column = rewriter.create(op.getLoc(), + rewriter.getIndexType(), column); + changed = true; + } + + if (!changed) + return failure(); + + rewriter.replaceOpWithNewOp(op, op.getType(), handle, row, + column); + + return success(changed); + } +}; + +void GetValueOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert(context); +} + +class NumResultsOpTypeFix final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(NumResultsOp op, + PatternRewriter &rewriter) const override { + bool changed = false; + Value handle = op->getOperand(0); + + if (handle.getType().isa() && + op->getResultTypes()[0].isa()) + return failure(); + + if (!handle.getType().isa()) { + handle = rewriter.create(op.getLoc(), + rewriter.getIndexType(), handle); + changed = true; + } + + mlir::Value res = rewriter.create( + op.getLoc(), rewriter.getIndexType(), handle); + + if (op->getResultTypes()[0].isa()) { + rewriter.replaceOp(op, res); + } else { + rewriter.replaceOpWithNewOp(op, op->getResultTypes()[0], + res); + } + + return success(changed); + } +}; + +void NumResultsOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert(context); +} + +// class ExecuteOpTypeFix final : public OpRewritePattern { +// public: +// using OpRewritePattern::OpRewritePattern; + +// LogicalResult matchAndRewrite(ExecuteOp op, +// PatternRewriter &rewriter) const override { +// bool changed = false; + +// Value conn = op->getOperand(0); +// Value command = op->getOperand(1); + +// if (conn.getType().isa() && command.getType().isa() +// && op->getResultTypes()[0].isa()) +// return failure(); + +// if (!conn.getType().isa()) { +// conn = rewriter.create(op.getLoc(), +// rewriter.getIndexType(), +// conn); +// changed = true; +// } +// if (command.getType().isa()) { +// command = rewriter.create(op.getLoc(), +// LLVM::LLVMPointerType::get(rewriter.getI8Type()), +// command); +// changed = true; +// } + +// if (command.getType().isa()) { +// command = rewriter.create(op.getLoc(), +// rewriter.getI64Type(), +// command); +// changed = true; +// } +// if (!command.getType().isa()) { +// command = rewriter.create(op.getLoc(), +// rewriter.getIndexType(), +// command); +// changed = true; +// } + +// if (!changed) return failure(); +// mlir::Value res = rewriter.create(op.getLoc(), +// rewriter.getIndexType(), conn, command); rewriter.replaceOp(op, res); +// // if (op->getResultTypes()[0].isa()) { +// // rewriter.replaceOp(op, res); +// // } else { +// // rewriter.replaceOpWithNewOp(op, +// op->getResultTypes()[0], res); +// // } +// return success(changed); +// } +// }; + +// void ExecuteOp::getCanonicalizationPatterns(RewritePatternSet &results, +// MLIRContext *context) { +// results.insert(context); +// } + +template +class UnparsedOpInnerCast final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(UnparsedOp op, + PatternRewriter &rewriter) const override { + + Value input = op->getOperand(0); + + auto cst = input.getDefiningOp(); + if (!cst) + return failure(); + + rewriter.replaceOpWithNewOp(op, op.getType(), cst.getOperand()); + return success(); + } +}; + +void UnparsedOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert>(context); +} + +class SQLStringConcatOpCanonicalization final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SQLStringConcatOp op, + PatternRewriter &rewriter) const override { + // Whether we changed the state. If we make no simplifications we need to + // return failure otherwise we will infinite loop + bool changed = false; + // Operands to the simplified concat + SmallVector operands; + // Constants that we will merge, "current running constant" + SmallVector constants; + for (auto op : op->getOperands()) { + if (auto constOp = op.getDefiningOp()) { + constants.push_back(constOp); + continue; + } + if (constants.size() != 0) { + if (constants.size() == 1) { + operands.push_back(constants[0]); + } else { + std::string nextStr; + changed = true; + for (auto str : constants) + nextStr += str.getInput().str(); + + operands.push_back(rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), nextStr)); + } + } + constants.clear(); + if (auto concat = op.getDefiningOp()) { + changed = true; + for (auto op2 : concat->getOperands()) + operands.push_back(op2); + continue; + } + operands.push_back(op); + } + if (constants.size() != 0) { + if (constants.size() == 1) { + operands.push_back(constants[0]); + } else { + std::string nextStr; + changed = true; + for (auto str : constants) + nextStr = nextStr + str.getInput().str(); + operands.push_back(rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), nextStr)); + } + } + if (operands.size() == 0) { + rewriter.replaceOpWithNewOp(op, MemRefType::get({-1}, rewriter.getI8Type()), ""); + return success(); + } + if (operands.size() == 1) { + rewriter.replaceOp(op, operands[0]); + return success(); + } + if (changed) { + rewriter.replaceOpWithNewOp(op, MemRefType::get({-1}, rewriter.getI8Type()), operands); + return success(); + } + return failure(); + } +}; + +void SQLStringConcatOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert(context); +} diff --git a/lib/sql/Parser.cpp b/lib/sql/Parser.cpp new file mode 100644 index 000000000000..86125e1b7fc1 --- /dev/null +++ b/lib/sql/Parser.cpp @@ -0,0 +1,320 @@ +#include +#include +#include +#include + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/SmallVector.h" + +#include "sql/SQLDialect.h" +#include "sql/SQLOps.h" +#include "sql/SQLTypes.h" + +using namespace mlir; +using namespace sql; +enum class ParseType { + Nothing = 0, + Value = 1, + Attribute = 2, +}; + +struct ParseValue { +private: + ParseType ty; + Value value; + Attribute attr; + +public: + ParseValue() : ty(ParseType::Nothing), value(nullptr), attr(nullptr) {} + ParseValue(Value value) : ty(ParseType::Value), value(value), attr(nullptr) { + assert(value); + } + ParseValue(Attribute attr) + : ty(ParseType::Attribute), value(nullptr), attr(attr) {} + + ParseType getType() const { return ty; } + Value getValue() const { + assert(ty == ParseType::Value); + assert(value); + return value; + } + Attribute getAttr() const { + assert(ty == ParseType::Attribute); + return attr; + } +}; + +enum class ParseMode { + None, + Column, + Table, + Bool, + Clause +}; + +template +std::ostream &operator<<( + typename std::enable_if::value, std::ostream>::type &stream, + const T &e) { + return stream << static_cast::type>(e); +} + +class SQLParser { + +public: + Location loc; + OpBuilder &builder; + std::string sql; + unsigned int i; + + static std::vector reservedWords; + + SQLParser(Location loc, OpBuilder &builder, std::string sql, int i) + : loc(loc), builder(builder), sql(sql), i(i) {} + + std::string peek() { + auto [peeked, _] = peekWithLength(); + return peeked; + } + + std::string pop() { + auto [peeked, len] = peekWithLength(); + i += len; + popWhitespace(); + return peeked; + } + + void popWhitespace() { + // it doesn't recognize + while (i < sql.size() && + (sql[i] == ' ' || sql[i] == '\n' || sql[i] == '\t')) { + i++; + } + } + + std::pair peekWithLength() { + if (i >= sql.size()) { + return {"", 0}; + } + for (std::string rWord : reservedWords) { + auto token = sql.substr(i, std::min(sql.size() - i, rWord.size())); + std::transform(token.begin(), token.end(), token.begin(), ::toupper); + if (token == rWord) { + return {token, static_cast(token.size())}; + } + } + if (sql[i] == '\'') { // Quoted string + return peekQuotedStringWithLength(); + } + return peekIdentifierWithLength(); + } + + std::pair peekQuotedStringWithLength() { + if (sql.size() < i || sql[i] != '\'') { + return {"", 0}; + } + for (unsigned int j = i + 1; j < sql.size(); j++) { + if (sql[j] == '\'' && sql[j - 1] != '\\') { + return {sql.substr(i + 1, j - (i + 1)), + j - i + 2}; // +2 for the two quotes + } + } + return {"", 0}; + } + + std::pair peekIdentifierWithLength() { + std::regex e("[a-zA-Z0-9_*]"); + for (unsigned int j = i; j < sql.size(); j++) { + if (!std::regex_match(std::string(1, sql[j]), e)) { + return {sql.substr(i, j - i), j - i}; + } + } + return {sql.substr(i), static_cast(sql.size()) - i}; + } + + bool is_number(std::string *s) { + std::string::const_iterator it = s->begin(); + while (it != s->end() && std::isdigit(*it)) + ++it; + return !s->empty() && it == s->end(); + } + + // Parse the next command, if any + ParseValue parseNext(ParseMode mode) { + // for (unsigned int j = i; j < sql.size(); j++) { + // auto peekStr = peek(); + // pop(); + // llvm::errs() << "peekStrTest: " << i << " " << peekStr << "\n"; + // } + if (i >= sql.size()) { + return ParseValue(); + } + auto peekStr = peek(); + llvm::errs() << "peekStr: " << peekStr << "\n"; + assert(peekStr.size() > 0); + + if (peekStr == "SELECT") { + assert(mode == ParseMode::None || mode == ParseMode::Table); + pop(); + peekStr = peek(); + if (peekStr == "DISTINCT") { + pop(); + // do something different here + } + llvm::SmallVector columns; + bool hasColumns = true; + bool hasWhere = false; + int limit = -1; + Value table = nullptr; + Value where = nullptr; + while (true) { + peekStr = peek(); + if (peekStr == "") + break; + if (hasColumns) { + if (peekStr == "FROM") { + pop(); + table = parseNext(ParseMode::Table).getValue(); + hasColumns = false; + } else { + if (peekStr == ",") { + pop(); + continue; + } + ParseValue col = parseNext(ParseMode::Column); + if (col.getType() == ParseType::Nothing) { + hasColumns = false; + break; + } else { + columns.push_back(col.getValue()); + } + } + } else if (peekStr == "WHERE") { + pop(); + ParseValue clause = parseNext(ParseMode::Bool); + if (clause.getType() == ParseType::Nothing) { + assert(0 && "where clause not recognized"); + } else { + where = builder.create(loc, ExprType::get(builder.getContext()), + clause.getValue()).getResult(); + } + } else if (peekStr == "LIMIT") { + pop(); + peekStr = peek(); + llvm::errs() << "limit: " << peekStr << "\n"; + if (peekStr == "ALL") { + pop(); + } else if (is_number(&peekStr)) { + llvm::errs() << "limit recognized: " + << "\n"; + pop(); + limit = std::stoi(peekStr); + } else { + assert(0 && "not yet handled limit var"); + } + } else { + // break; + llvm::errs() << "peekstr that throws an error" << peekStr << "\n"; + assert(0 && " additional clauses like where/etc not yet handled"); + } + } + if (!table) { + llvm::errs() << " table is null: " << table << "\n"; + table = builder.create(loc, ExprType::get(builder.getContext()), + builder.getStringAttr("")).getResult(); + } + if (!where){ + llvm::errs() << " where is null: " << table << "\n"; + Value clause = builder.create(loc, BoolType::get(builder.getContext()), + builder.getStringAttr("")).getResult(); + where = builder.create(loc, ExprType::get(builder.getContext()), + clause).getResult(); + } + return ParseValue( + builder.create(loc, ExprType::get(builder.getContext()), + columns, table, where, limit).getResult()); + } else if (is_number(&peekStr)) { + pop(); + return ParseValue(builder.create(loc, + ExprType::get(builder.getContext()), + builder.getStringAttr(peekStr)).getResult()); + } else if (mode == ParseMode::Column) { + if (peekStr == "*") { + pop(); + return ParseValue( + builder + .create(loc, ExprType::get(builder.getContext())) + .getResult()); + } + pop(); + return ParseValue( + builder + .create(loc, ExprType::get(builder.getContext()), + builder.getStringAttr(peekStr)) + .getResult()); + } else if (mode == ParseMode::Table) { + pop(); + return ParseValue( + builder + .create(loc, ExprType::get(builder.getContext()), + builder.getStringAttr(peekStr)) + .getResult()); + } else if (mode == ParseMode::Bool) { + // col = peekStr; + ParseValue left = parseNext(ParseMode::Clause); + peekStr = peek(); + if (peekStr == "AND") { + pop(); + ParseValue right = parseNext(ParseMode::Bool); + return ParseValue( + builder.create(loc, BoolType::get(builder.getContext()), left.getValue(), right.getValue()).getResult() + ); + } else if (peekStr == "OR") { + pop(); + ParseValue right = parseNext(ParseMode::Bool); + return ParseValue( + builder.create(loc, BoolType::get(builder.getContext()), left.getValue(), + right.getValue()).getResult() + ); + } else return left; + + + } else if (mode == ParseMode::Clause){ + std::string clause = pop(); + clause += " " + pop(); + clause += " " + pop(); + return ParseValue( + builder.create(loc, BoolType::get(builder.getContext()), + builder.getStringAttr(clause)).getResult() + ); + } else if (peekStr == "(") { + pop(); + ParseValue res = parseNext(ParseMode::None); + assert(peek() == ")"); + pop(); + return res; + } else if (peekStr == ")") { + return ParseValue(); + } + llvm::errs() << " Unknown token to parse: " << peekStr << "\n"; + llvm_unreachable("Unknown token to parse"); + } +}; + +std::vector SQLParser::reservedWords = { + "(", ")", ">=", "<=", "!=", + ",", "=", ">", "<", ",", + "SELECT", "DISTINCT", "INSERT INTO", "VALUES", "UPDATE", + "DELETE FROM", "WHERE", "FROM", "SET", "AS"}; + +mlir::Value parseSQL(mlir::Location loc, mlir::OpBuilder &builder, + std::string str) { + SQLParser parser(loc, builder, std::string(str), 0); + auto resOp = parser.parseNext(ParseMode::None); + + return resOp.getValue(); +} \ No newline at end of file diff --git a/lib/sql/Passes/CMakeLists.txt b/lib/sql/Passes/CMakeLists.txt new file mode 100644 index 000000000000..ee492cf8f26c --- /dev/null +++ b/lib/sql/Passes/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_dialect_library(MLIRSQLTransforms + SQLLower.cpp + SQLRaising.cpp + + DEPENDS + MLIRPolygeistOpsIncGen + MLIRPolygeistPassIncGen + MLIRSQLPassIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRFuncDialect + MLIRFuncTransforms + MLIRIR + MLIRLLVMDialect + MLIRMathDialect + MLIRMemRefDialect + MLIRPass + ) \ No newline at end of file diff --git a/lib/sql/Passes/PassDetails.h b/lib/sql/Passes/PassDetails.h new file mode 100644 index 000000000000..1436f9e69809 --- /dev/null +++ b/lib/sql/Passes/PassDetails.h @@ -0,0 +1,39 @@ +//===- PassDetails.h - polygeist pass class details ----------------*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Stuff shared between the different polygeist passes. +// +//===----------------------------------------------------------------------===// + +// clang-tidy seems to expect the absolute path in the header guard on some +// systems, so just disable it. +// NOLINTNEXTLINE(llvm-header-guard) +#ifndef DIALECT_POLYGEIST_TRANSFORMS_PASSDETAILS_H +#define DIALECT_POLYGEIST_TRANSFORMS_PASSDETAILS_H + +#include "mlir/Pass/Pass.h" +#include "sql/SQLOps.h" +#include "sql/Passes/Passes.h" + +namespace mlir { +class FunctionOpInterface; +// Forward declaration from Dialect.h +template +void registerDialect(DialectRegistry ®istry); +namespace sql { + +class SQLDialect; + +#define GEN_PASS_CLASSES +#include "sql/Passes/Passes.h.inc" + +} // namespace polygeist +} // namespace mlir + +#endif // DIALECT_POLYGEIST_TRANSFORMS_PASSDETAILS_H diff --git a/lib/sql/Passes/SQLLower.cpp b/lib/sql/Passes/SQLLower.cpp new file mode 100644 index 000000000000..eb0483e4bb32 --- /dev/null +++ b/lib/sql/Passes/SQLLower.cpp @@ -0,0 +1,527 @@ +//===- SQLLower.cpp - Lower PostgreSQL to sql mlir ops ------ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to lower gpu kernels in NVVM/gpu dialects into +// a generic SQL for representation +//===----------------------------------------------------------------------===// + +#include "PassDetails.h" +#include "mlir/Analysis/CallGraph.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "polygeist/Ops.h" +#include "sql/Passes/Passes.h" +#include "sql/SQLOps.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include +#include + +#define DEBUG_TYPE "sql-lower-opt" + +using namespace mlir; +using namespace mlir::arith; +using namespace mlir::func; +using namespace mlir::sql; + +namespace { +struct SQLLower : public SQLLowerBase { + void runOnOperation() override; +}; + +} // end anonymous namespace + +struct NumResultsOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(sql::NumResultsOp op, + PatternRewriter &rewriter) const final { + auto module = op->getParentOfType(); + + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + + // 1) make sure the postgres_getresult function is declared + auto rowsfn = dyn_cast_or_null(symbolTable.lookupSymbolIn( + module, rewriter.getStringAttr("PQntuples"))); + + // 2) convert the args to valid args to postgres_getresult abi + Value arg = op.getHandle(); + arg = rewriter.create(op.getLoc(), rewriter.getI8Type(), + arg); + + arg = rewriter.create( + op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), arg); + + arg = rewriter.create( + op.getLoc(), rowsfn.getFunctionType().getInput(0), arg); + + // 3) call and replace + Value args[] = {arg}; + + Value res = rewriter.create(op.getLoc(), rowsfn, args) + ->getResult(0); + + rewriter.replaceOpWithNewOp(op, rewriter.getIndexType(), + res); + + return success(); + } +}; + +struct GetValueOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(sql::GetValueOp op, + PatternRewriter &rewriter) const final { + auto module = op->getParentOfType(); + + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + + // 1) make sure the postgres_getresult function is declared + auto valuefn = dyn_cast_or_null(symbolTable.lookupSymbolIn( + module, rewriter.getStringAttr("PQgetvalue"))); + + auto atoifn = dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("atoi"))); + + // 2) convert the args to valid args to postgres_getresult abi + Value handle = op.getHandle(); + handle = rewriter.create(op.getLoc(), + rewriter.getI64Type(), handle); + handle = rewriter.create( + op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), handle); + + handle = rewriter.create( + op.getLoc(), valuefn.getFunctionType().getInput(0), handle); + + Value row = op.getRow(); + row = rewriter.create( + op.getLoc(), valuefn.getFunctionType().getInput(1), row); + Value column = op.getColumn(); + column = rewriter.create( + op.getLoc(), valuefn.getFunctionType().getInput(2), column); + + Value args[] = {handle, row, column}; + + Value res = rewriter.create(op.getLoc(), valuefn, args) + ->getResult(0); + + Value args2[] = {res}; + + Value res2 = + rewriter.create(op.getLoc(), atoifn, args2) + ->getResult(0); + + if (op.getType() != res2.getType()) { + if (op.getType().isa()) + res2 = rewriter.create(op.getLoc(), op.getType(), + res2); + else if (auto IT = op.getType().dyn_cast()) { + auto IT2 = res2.getType().dyn_cast(); + if (IT.getWidth() < IT2.getWidth()) { + res2 = + rewriter.create(op.getLoc(), op.getType(), res2); + } else if (IT.getWidth() > IT2.getWidth()) { + res2 = + rewriter.create(op.getLoc(), op.getType(), res2); + } else + assert(0 && "illegal integer type conversion"); + } else { + assert(0 && "illegal type conversion"); + } + } + rewriter.replaceOp(op, res2); + + return success(); + } +}; + +struct ConstantStringOpLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(sql::SQLConstantStringOp op, + PatternRewriter &rewriter) const final { + auto module = op->getParentOfType(); + + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + for (auto u : op.getResult().getUsers()) { + if (isa(u)) + return failure(); + } + auto expr = op.getInput().str(); + auto name = "str" + std::to_string((long long int)(Operation *)op); + auto MT = MemRefType::get({expr.size() + 1}, rewriter.getI8Type()); + // auto type = MemRefType::get(mt.getShape(), mt.getElementType(), {}); + auto getglob = rewriter.create(op.getLoc(), MT, name); + + SmallVector data(expr.begin(), expr.end()); + data.push_back('\0'); + auto attr = DenseElementsAttr::get( + RankedTensorType::get(MT.getShape(), MT.getElementType()), data); + + auto loc = op.getLoc(); + rewriter.replaceOpWithNewOp( + op, MemRefType::get({-1}, rewriter.getI8Type()), getglob.getResult()); + rewriter.setInsertionPointToStart(module.getBody()); + auto res = rewriter.create( + loc, rewriter.getStringAttr(name), mlir::StringAttr(), + mlir::TypeAttr::get(MT), attr, rewriter.getUnitAttr(), + /*alignment*/ nullptr); + + return success(); + } +}; + + +struct BoolToStringOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(sql::SQLBoolToStringOp op, + PatternRewriter &rewriter) const final { + auto module = op->getParentOfType(); + + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + + // 2) convert the args to valid args to postgres_getresult abi + Value expr = op.getExpr(); + auto definingOp = expr.getDefiningOp(); + if (auto andOp = dyn_cast(definingOp)) { + Value left = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + andOp.getLeft()); + Value right = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + andOp.getRight()); + Value args[] = {left, rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + "AND "), right}; + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + rewriter.replaceOp(op, res); + } else if (auto orOp = dyn_cast(definingOp)) { + Value left = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + orOp.getLeft()); + Value right = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + orOp.getRight()); + Value args[] = {left, rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + "OR "), right}; + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + rewriter.replaceOp(op, res); + } else if (auto calcBoolOp = dyn_cast(definingOp)){ + auto expr = calcBoolOp.getExpr(); + if (expr == "") { + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + ""); + rewriter.replaceOp(op, res); + return success(); + } + + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), expr); + Value args[] = {res, rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), " ")}; + res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + rewriter.replaceOp(op, res); + } else { + assert(0 && "unknown type to convert to string"); + } + + return success(); + } +}; + + +struct ToStringOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(sql::SQLToStringOp op, + PatternRewriter &rewriter) const final { + auto module = op->getParentOfType(); + + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + + // 2) convert the args to valid args to postgres_getresult abi + Value expr = op.getExpr(); + auto definingOp = expr.getDefiningOp(); + if (auto selectOp = dyn_cast(definingOp)) { + Value current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), "SELECT "); + + bool prevColumn = false; + auto columns = selectOp.getColumns(); + for (mlir::Value v : columns) { + if (prevColumn) { + Value args[] = { + current, rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), ", ")}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + } + Value col = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), v); + Value args[] = { + col, + rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), " ")}; + col = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + Value args2[] = {current, col}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args2); + prevColumn = true; + } + + auto tableOp = selectOp.getTable().getDefiningOp(); + if (tableOp) { + Value args[] = { + current, rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), "FROM ")}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + Value args2[] = {current, + rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + selectOp.getTable())}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args2); + } + + Value whereVal = rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + selectOp.getWhere()); + if (whereVal) { + auto whereOp = selectOp.getWhere().getDefiningOp(); + Value args[] = { + current, rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), "WHERE ")}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + Value args2[] = {current, whereVal}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args2); + } + + + auto limit = selectOp.getLimit(); + if (limit >= 0) { + Value args[] = {current, + rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + "LIMIT ")}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + Value args2[] = {current, + rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + std::to_string(limit))}; + current = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args2); + } + rewriter.replaceOp(op, current); + } else if (auto tabOp = dyn_cast(definingOp)) { + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + tabOp.getExpr()); + Value args[] = { res, rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), " ")}; + res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), args); + rewriter.replaceOp(op, res); + } else if (auto allColOp = dyn_cast(definingOp)) { + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), "*"); + rewriter.replaceOp(op, res); + } else if (auto colOp = dyn_cast(definingOp)) { + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + colOp.getExpr()); + rewriter.replaceOp(op, res); + } else if (auto intOp = dyn_cast(definingOp)) { + Value res = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + intOp.getExpr()); + llvm::errs() << "intOp: " << intOp.getExpr() << "\n"; + rewriter.replaceOp(op, res); + } else if (auto whereOp = dyn_cast(definingOp)) { + Value res = rewriter.create( + op.getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), + whereOp.getExpr()); + rewriter.replaceOp(op, res); + } else { + assert(0 && "unknown type to convert to string"); + } + + return success(); + } +}; + + + +struct ExecuteOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(sql::ExecuteOp op, + PatternRewriter &rewriter) const final { + auto module = op->getParentOfType(); + + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + + // 1) make sure the postgres_getresult function is declared + auto executefn = dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("PQexec"))); + + // 2) convert the args to valid args to postgres_getresult abi + Value conn = op.getConn(); + conn = rewriter.create(op.getLoc(), + rewriter.getI8Type(), conn); + conn = rewriter.create( + op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), conn); + conn = rewriter.create( + op.getLoc(), executefn.getFunctionType().getInput(0), conn); + + Value command = op.getCommand(); + // auto name = "str" + std::to_string((long long int)(Operation + // *)command.getDefiningOp()); + command = rewriter.create( + op.getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), command); + llvm::errs() << "command: " << command << "\n"; + llvm::errs() << "command type: " << command.getType() << "\n"; + // auto type = MemRefType::get({-1}, rewriter.getI8Type()); + + // auto globalOp = rewriter.create( + // op.getLoc(), type, /* isConstant */ true, LLVM::Linkage::Internal, + // name, mlir::Attribute(), + // /* alignment */ 0, /* addrSpace */ 0); + + // 3) call and replace + Value args[] = {conn, command}; + + Value res = + rewriter.create(op.getLoc(), executefn, args) + ->getResult(0); + res = rewriter.create( + op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), res); + res = rewriter.create(op.getLoc(), rewriter.getI64Type(), + res); + res = rewriter.create(op.getLoc(), op.getType(), res); + + rewriter.replaceOp(op, res); + + return success(); + } +}; + +void SQLLower::runOnOperation() { + auto module = getOperation(); + + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + OpBuilder builder(module.getContext()); + builder.setInsertionPointToStart(module.getBody()); + + if (!dyn_cast_or_null(symbolTable.lookupSymbolIn( + module, builder.getStringAttr("PQntuples")))) { + mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type())}; + mlir::Type rettypes[] = {builder.getI64Type()}; + + auto fn = builder.create( + module.getLoc(), "PQntuples", + builder.getFunctionType(argtypes, rettypes)); + SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); + } + + if (!dyn_cast_or_null(symbolTable.lookupSymbolIn( + module, builder.getStringAttr("PQgetvalue")))) { + mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type()), + builder.getI64Type(), builder.getI64Type()}; + mlir::Type rettypes[] = {MemRefType::get({-1}, builder.getI8Type())}; + + auto fn = builder.create( + module.getLoc(), "PQgetvalue", + builder.getFunctionType(argtypes, rettypes)); + SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); + } + + if (!dyn_cast_or_null(symbolTable.lookupSymbolIn( + module, builder.getStringAttr("PQexec")))) { + mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type()), + MemRefType::get({-1}, builder.getI8Type())}; + mlir::Type rettypes[] = {MemRefType::get({-1}, builder.getI8Type())}; + + auto fn = builder.create( + module.getLoc(), "PQexec", builder.getFunctionType(argtypes, rettypes)); + SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); + } + + if (!dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, builder.getStringAttr("atoi")))) { + mlir::Type argtypes[] = {MemRefType::get({-1}, builder.getI8Type())}; + // mlir::Type argtypes[] = + // {LLVM::LLVMPointerType::get(builder.getI64Type())}; + + // todo use data layout + mlir::Type rettypes[] = {builder.getI64Type()}; + + auto fn = builder.create( + module.getLoc(), "atoi", builder.getFunctionType(argtypes, rettypes)); + SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); + } + + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); + patterns.insert(&getContext()); + patterns.insert(&getContext()); + patterns.insert(&getContext()); + patterns.insert(&getContext()); + patterns.insert(&getContext()); + + for (auto *dialect : getContext().getLoadedDialects()) + dialect->getCanonicalizationPatterns(patterns); + for (RegisteredOperationName op : getContext().getRegisteredOperations()) + op.getCanonicalizationPatterns(patterns, &getContext()); + + GreedyRewriteConfig config; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config); +} + +namespace mlir { +namespace sql { +std::unique_ptr createSQLLowerPass() { + return std::make_unique(); +} +} // namespace sql +} // namespace mlir diff --git a/lib/sql/Passes/SQLRaising.cpp b/lib/sql/Passes/SQLRaising.cpp new file mode 100644 index 000000000000..0c0352c06328 --- /dev/null +++ b/lib/sql/Passes/SQLRaising.cpp @@ -0,0 +1,224 @@ +//===- SQLLower.cpp - Lower PostgreSQL to sql mlir ops ------ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to lower gpu kernels in NVVM/gpu dialects into +// a generic SQL for representation +//===----------------------------------------------------------------------===// + +#include "PassDetails.h" +#include "polygeist/Ops.h" +#include "mlir/Analysis/CallGraph.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "sql/SQLOps.h" +#include "sql/Parser.h" +#include "sql/Passes/Passes.h" +#include +#include + +#define DEBUG_TYPE "sql-raising-opt" + +using namespace mlir; +using namespace mlir::arith; +using namespace mlir::func; +using namespace mlir::sql; + +namespace { +struct SQLRaising : public SQLRaisingBase { + void runOnOperation() override; +}; + +} // end anonymous namespace + +struct PQntuplesRaising : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(func::CallOp call, + PatternRewriter &rewriter) const final { + if (call.getCallee() != "PQntuples") { + return failure(); + } + auto module = call->getParentOfType(); + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + + + // 2) convert the args to valid args to postgres_getresult abi + Value arg = call.getArgOperands()[0]; + arg = rewriter.create( + call.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), arg); + arg = rewriter.create(call.getLoc(), rewriter.getIntegerType(64), arg); + arg = rewriter.create(call.getLoc(), rewriter.getIndexType(), arg); + + + Value res = rewriter.create(call.getLoc(), rewriter.getIndexType(), arg); + res = rewriter.create(call.getLoc(), + rewriter.getI64Type(), res); + rewriter.replaceOp(call, res); + return success(); + } +}; + +struct PQgetvalueRaising : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(func::CallOp call, + PatternRewriter &rewriter) const final { + if (call.getCallee() != "PQgetvalue") { + return failure(); + } + auto module = call->getParentOfType(); + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + + // 2) convert the args to valid args to postgres_getresult abi + Value handle = call.getArgOperands()[0]; + handle = rewriter.create( + call.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), handle); + handle = rewriter.create(call.getLoc(), rewriter.getIntegerType(64), handle); + handle = rewriter.create(call.getLoc(), rewriter.getIndexType(), handle); + + Value row = call.getArgOperands()[1]; + row = rewriter.create(call.getLoc(), rewriter.getIndexType(), row); + Value column = call.getArgOperands()[2]; + column = rewriter.create(call.getLoc(), rewriter.getIndexType(), column); + + Value res = rewriter.create(call.getLoc(), rewriter.getIndexType(), handle, row, column); + + res = rewriter.create(call.getLoc(), + rewriter.getI64Type(), res); + + Value args2[] = {res}; + + + auto itoafn = dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, rewriter.getStringAttr("itoa"))); + + rewriter.replaceOpWithNewOp(call, itoafn, args2); + + return success(); + } +}; + + +struct PQexecRaising : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(func::CallOp call, + PatternRewriter &rewriter) const final { + if (call.getCallee() != "PQexec") { + return failure(); + } + + // 2) convert the args to valid args to postgres_getresult abi + Value conn = call.getArgOperands()[0]; + if (!conn.getType().isa()) { + conn = rewriter.create( + call.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), conn); + } + conn = rewriter.create( + call.getLoc(), rewriter.getIntegerType(64), conn); + + conn = rewriter.create(call.getLoc(), + rewriter.getIndexType(), conn); + + Value command = call.getArgOperands()[1]; + + command = rewriter.create(call.getLoc(), + ExprType::get(rewriter.getContext()), command); + + Value res = rewriter.create(call.getLoc(), rewriter.getIndexType(), conn, command); + + res = rewriter.create(call.getLoc(), + rewriter.getI64Type(), res); + res = rewriter.create(call.getLoc(), LLVM::LLVMPointerType::get(rewriter.getI8Type()), res); + res = rewriter.create(call.getLoc(), call.getResultTypes()[0], res); + + + assert(call.getResultTypes()[0] == res.getType()); + rewriter.replaceOp(call, res); + + return success(); + } +}; + + +class UnparsedOpParseFix final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(UnparsedOp op, + PatternRewriter &rewriter) const override { + + Value input = op->getOperand(0); + + auto stringOp = input.getDefiningOp(); + if (!stringOp) return failure(); + + StringRef strname = stringOp.getGlobalName(); + + auto module = op->getParentOfType(); + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + + Attribute strattr = dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, rewriter.getStringAttr(strname))).getValueAttr(); + auto str = strattr.cast().getValue(); + + auto resOp = parseSQL(op.getLoc(), rewriter, std::string(str.data())); + assert(resOp); + rewriter.replaceOp(op,ValueRange(resOp)); + return success(); + } +}; + +void SQLRaising::runOnOperation() { + auto module = getOperation(); + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(module); + auto &context = getContext(); + OpBuilder builder(&context); + builder.setInsertionPointToStart(module.getBody()); + + if (!dyn_cast_or_null( + symbolTable.lookupSymbolIn(module, builder.getStringAttr("itoa")))) { + mlir::Type argtypes[] = {builder.getI64Type()}; + mlir::Type rettypes[] = {MemRefType::get({-1}, builder.getI8Type())}; + + auto fn = builder.create( + module.getLoc(), "itoa", builder.getFunctionType(argtypes, rettypes)); + SymbolTable::setSymbolVisibility(fn, SymbolTable::Visibility::Private); + } + + RewritePatternSet patterns(&context); + patterns.insert(&getContext()); + + for (auto *dialect : context.getLoadedDialects()) + dialect->getCanonicalizationPatterns(patterns); + for (RegisteredOperationName op : context.getRegisteredOperations()) + op.getCanonicalizationPatterns(patterns, &context); + + GreedyRewriteConfig config; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config); +} + +namespace mlir { +namespace sql { +std::unique_ptr createSQLRaisingPass() { + return std::make_unique(); +} +} // namespace polygeist +} // namespace mlir diff --git a/lib/sql/Types.cpp b/lib/sql/Types.cpp new file mode 100644 index 000000000000..56d682068602 --- /dev/null +++ b/lib/sql/Types.cpp @@ -0,0 +1,37 @@ +//===- SQLTypes.cpp - SQL dialect types ---------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "sql/SQLDialect.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/DialectImplementation.h" +#include "sql/SQLTypes.h" + + +#define DEBUG_TYPE "sql" + +using namespace mlir::sql; + +#define GET_TYPEDEF_CLASSES +#include "sql/SQLOpsTypes.cpp.inc" + + +void SQLDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "sql/SQLOpsTypes.cpp.inc" + >(); +} \ No newline at end of file diff --git a/test/polygeist-opt/sql.mlir b/test/polygeist-opt/sql.mlir new file mode 100644 index 000000000000..024215c97895 --- /dev/null +++ b/test/polygeist-opt/sql.mlir @@ -0,0 +1,11 @@ +// RUN: polygeist-opt %s | FileCheck %s +// -lower-sql +module { + func.func private @run() -> i32 { + %c0 = arith.constant 0 : index + %q = "sql.select"() {column = ["data"], table = "mytable"} : () -> index + %h = "sql.execute"(%q) : (index) -> index + %res = "sql.get_result"(%h, %c0) {column = "data"} : (index, index) -> i32 + return %res : i32 + } +} \ No newline at end of file diff --git a/test_with_pragma.c b/test_with_pragma.c new file mode 100644 index 000000000000..7c144f328c5b --- /dev/null +++ b/test_with_pragma.c @@ -0,0 +1,57 @@ +#include +#include +#include + +// PGresult *PQexec(PGconn*, const char* command); +// PQgetvalue +// %7 = call @PQexec(%2, %6) : (memref, memref) -> +// memref +// #pragma lower_to(num_rows_fn, "sql.num_results") +// int num_rows_fn(size_t);// char* + +// #pragma lower_to(get_value_fn_int, "sql.get_value") +// int get_value_fn_int(size_t, int, int); + +// #pragma lower_to(get_value_fn_double, "sql.get_value") +// double get_value_fn_double(size_t, int, int); + +// #pragma lower_to(execute, "sql.execute") +// PGresult* execute(size_t, char*); + +void do_exit(PGconn *conn) { + PQfinish(conn); + exit(1); +} + +int main() { + + PGconn *conn = PQconnectdb("user=carl dbname=testdb"); + + if (PQstatus(conn) == CONNECTION_BAD) { + + fprintf(stderr, "Connection to database failed: %s\n", + PQerrorMessage(conn)); + do_exit(conn); + } + + // PGresult *res = PQexec(conn, "SELECT 17"); + PGresult *res = PQexec(conn, "SELECT a FROM table1"); + PGresult *res1 = PQexec(conn, "SELECT * FROM table1 WHERE b > 10 OR c < 10 AND a <= 20"); + PGresult *res2 = PQexec(conn, "SELECT * FROM table1 WHERE b > 10 AND c < 10"); + PGresult *res3 = PQexec(conn, "SELECT b, c FROM table1 WHERE a <= 10 LIMIT 10"); + // PGresult *res3 = PQexec(conn, "SELECT b, c FROM table1 LIMIT ALL"); + if (PQresultStatus(res) != PGRES_TUPLES_OK) { + + printf("No data retrieved\n"); + PQclear(res); + do_exit(conn); + } + + PQclear(res); + PQclear(res1); + PQclear(res2); + PQclear(res3); + PQfinish(conn); + + return 0; +} diff --git a/tools/cgeist/CMakeLists.txt b/tools/cgeist/CMakeLists.txt index 1b0e7434c773..c0663093a7c5 100644 --- a/tools/cgeist/CMakeLists.txt +++ b/tools/cgeist/CMakeLists.txt @@ -59,7 +59,6 @@ target_compile_definitions(cgeist PUBLIC -DLLVM_OBJ_ROOT="${LLVM_BINARY_DIR}") target_link_libraries(cgeist PRIVATE MLIRSCFTransforms MLIRPolygeist - MLIRSupport MLIRIR MLIRAnalysis @@ -108,5 +107,11 @@ target_link_libraries(cgeist PRIVATE clangLex clangSerialization ) + +if (ENABLE_SQL) + target_link_libraries(cgeist PRIVATE MLIRSQLTransforms MLIRSQL) +endif() + add_dependencies(cgeist MLIRPolygeistOpsIncGen MLIRPolygeistPassIncGen) +add_dependencies(cgeist MLIRSQLOpsIncGen MLIRSQLPassIncGen) add_subdirectory(Test) diff --git a/tools/cgeist/Lib/clang-mlir.cc b/tools/cgeist/Lib/clang-mlir.cc index 9af4714326da..34dfad6dfa6d 100644 --- a/tools/cgeist/Lib/clang-mlir.cc +++ b/tools/cgeist/Lib/clang-mlir.cc @@ -5606,6 +5606,10 @@ mlir::Type MLIRASTConsumer::getMLIRType(clang::QualType qt, bool *implicitRef, return retTy; } + if (ST->isOpaque()) { + OpBuilder builder(module->getContext()); + types.push_back(builder.getIntegerType(8)); + } if (!types.size()) { RT->dump(); llvm::errs() << "ST: " << *ST << "\n"; diff --git a/tools/cgeist/driver.cc b/tools/cgeist/driver.cc index 4fdd907ddff1..e58c13fcddcc 100644 --- a/tools/cgeist/driver.cc +++ b/tools/cgeist/driver.cc @@ -73,7 +73,9 @@ #include #include "polygeist/Dialect.h" +#include "sql/SQLDialect.h" #include "polygeist/Passes/Passes.h" +#include "sql/Passes/Passes.h" #include @@ -550,6 +552,7 @@ int main(int argc, char **argv) { context.getOrLoadDialect(); context.getOrLoadDialect(); context.getOrLoadDialect(); + context.getOrLoadDialect(); LLVM::LLVMFunctionType::attachInterface(context); LLVM::LLVMPointerType::attachInterface(context); diff --git a/tools/polygeist-opt/CMakeLists.txt b/tools/polygeist-opt/CMakeLists.txt index ccfebd421d81..3221aa842d0d 100644 --- a/tools/polygeist-opt/CMakeLists.txt +++ b/tools/polygeist-opt/CMakeLists.txt @@ -6,8 +6,11 @@ set(LIBS MLIROptLib MLIRPolygeist MLIRPolygeistTransforms - MLIRFuncAllExtensions ) +if (ENABLE_SQL) + set(LIBS ${LIBS} MLIRSQLTransforms MLIRSQL) +endif() + add_llvm_executable(polygeist-opt polygeist-opt.cpp) install(TARGETS polygeist-opt diff --git a/tools/polygeist-opt/polygeist-opt.cpp b/tools/polygeist-opt/polygeist-opt.cpp index 95fe1b1fc4a4..60d1241c0d3e 100644 --- a/tools/polygeist-opt/polygeist-opt.cpp +++ b/tools/polygeist-opt/polygeist-opt.cpp @@ -33,6 +33,10 @@ #include "polygeist/Dialect.h" #include "polygeist/Passes/Passes.h" +#include "sql/SQLDialect.h" +#include "sql/SQLOps.h" +#include "sql/Passes/Passes.h" + using namespace mlir; class MemRefInsider @@ -62,8 +66,10 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); + registry.insert(); mlir::registerpolygeistPasses(); mlir::func::registerInlinerExtension(registry); + mlir::registersqlPasses(); // Register the standard passes we want. mlir::registerCSEPass();