From 9e3b416ba2e906724f1f8baf10bc9abdb8e3bf65 Mon Sep 17 00:00:00 2001 From: MahdiBM Date: Sun, 29 Oct 2023 17:25:02 +0330 Subject: [PATCH] initial commit --- Macros/PostgresRecordMacro/+SwiftSyntax.swift | 42 ++++++++ Macros/PostgresRecordMacro/Diagnoser.swift | 73 ++++++++++++++ Macros/PostgresRecordMacro/EntryPoint.swift | 13 +++ Macros/PostgresRecordMacro/MacroError.swift | 22 +++++ Macros/PostgresRecordMacro/ParsedType.swift | 80 +++++++++++++++ .../PostgresRecordMacroType.swift | 50 ++++++++++ .../PostgresRecordMacro/Variable+init.swift | 36 +++++++ Macros/PostgresRecordMacro/Variable.swift | 37 +++++++ Package.swift | 33 +++++-- Sources/PostgresKit/PostgresRecord.swift | 74 ++++++++++++++ .../MacroTests/PostgresRecordMacroTests.swift | 98 +++++++++++++++++++ 11 files changed, 551 insertions(+), 7 deletions(-) create mode 100644 Macros/PostgresRecordMacro/+SwiftSyntax.swift create mode 100644 Macros/PostgresRecordMacro/Diagnoser.swift create mode 100644 Macros/PostgresRecordMacro/EntryPoint.swift create mode 100644 Macros/PostgresRecordMacro/MacroError.swift create mode 100644 Macros/PostgresRecordMacro/ParsedType.swift create mode 100644 Macros/PostgresRecordMacro/PostgresRecordMacroType.swift create mode 100644 Macros/PostgresRecordMacro/Variable+init.swift create mode 100644 Macros/PostgresRecordMacro/Variable.swift create mode 100644 Sources/PostgresKit/PostgresRecord.swift create mode 100644 Tests/MacroTests/PostgresRecordMacroTests.swift diff --git a/Macros/PostgresRecordMacro/+SwiftSyntax.swift b/Macros/PostgresRecordMacro/+SwiftSyntax.swift new file mode 100644 index 0000000..c2658c2 --- /dev/null +++ b/Macros/PostgresRecordMacro/+SwiftSyntax.swift @@ -0,0 +1,42 @@ +import SwiftSyntax + +extension StructDeclSyntax { + func removingInheritedType(at idx: InheritedTypeListSyntax.Index) -> Self { + var new = self + new.inheritanceClause?.inheritedTypes.remove(at: idx) + /// Remove the colon after types name, in e.g. `MyTable: `, if no protocols are remaining. + if new.inheritanceClause?.inheritedTypes.isEmpty == true { + new.inheritanceClause = nil + } + return new.reformatted() + } + + var accessLevelModifier: String? { + let accessLevels: [Keyword] = [.open, .public, .package, .internal, .private, .fileprivate] + for modifier in self.modifiers { + guard let modifier = modifier.as(DeclModifierSyntax.self), + case let .keyword(keyword) = modifier.name.tokenKind else { + continue + } + if accessLevels.contains(keyword) { + return modifier.name.trimmedDescription + } + } + return nil + } + + /// https://github.com/apple/swift/pull/69448 + /// Remove whenever this bug-fix is live (when swift 5.9.2 is out?) + func conforms(to protocolName: String) -> Bool { + self.inheritanceClause?.inheritedTypes.contains { + $0.type.as(IdentifierTypeSyntax.self)?.name.trimmedDescription == protocolName + } == true + } +} + +extension SyntaxProtocol { + /// Build a syntax node from this `Buildable` and format it with the given format. + func reformatted() -> Self { + return self.formatted().as(Self.self)! + } +} diff --git a/Macros/PostgresRecordMacro/Diagnoser.swift b/Macros/PostgresRecordMacro/Diagnoser.swift new file mode 100644 index 0000000..2b21fcb --- /dev/null +++ b/Macros/PostgresRecordMacro/Diagnoser.swift @@ -0,0 +1,73 @@ +import SwiftSyntax +import SwiftSyntaxMacros +import SwiftDiagnostics + +struct Diagnoser { + let context: Context + + func cannotConformToProtocol( + name: String, + old: some SyntaxProtocol, + new: some SyntaxProtocol + ) { + let diagnosis = Diagnosis.cannotConformToProtocol(name) + context.diagnose(Diagnostic( + node: old, + position: old.position, + message: diagnosis.diagnosticMessage, + highlights: nil, + notes: [], + fixIt: .replace( + message: diagnosis.fixItMessage, + oldNode: old, + newNode: new + ) + )) + } +} + +private enum Diagnosis: Error { + case cannotConformToProtocol(String) + + private struct _DiagnosticMessage: DiagnosticMessage { + let parent: Diagnosis + + var message: String { + switch parent { + case let .cannotConformToProtocol(proto): + return "Simultaneous conformance to '\(proto)' is not supported" + } + } + + var diagnosticID: SwiftDiagnostics.MessageID { + .init(domain: "\(Self.self)", id: self.message) + } + + var severity: SwiftDiagnostics.DiagnosticSeverity { + .error + } + } + + private struct _FixItMessage: FixItMessage { + let parent: Diagnosis + + var message: String { + switch parent { + case let .cannotConformToProtocol(proto): + return "Remove conformance to '\(proto)'" + } + } + + var fixItID: SwiftDiagnostics.MessageID { + .init(domain: "\(Self.self)", id: self.message) + } + } + + var diagnosticMessage: any DiagnosticMessage { + _DiagnosticMessage(parent: self) + } + + var fixItMessage: any FixItMessage { + _FixItMessage(parent: self) + } +} diff --git a/Macros/PostgresRecordMacro/EntryPoint.swift b/Macros/PostgresRecordMacro/EntryPoint.swift new file mode 100644 index 0000000..78ed4cb --- /dev/null +++ b/Macros/PostgresRecordMacro/EntryPoint.swift @@ -0,0 +1,13 @@ +import SwiftSyntaxMacros +import SwiftCompilerPlugin + +@main +struct PostgresRecordMacroEntryPoint: CompilerPlugin { + static let macros: [String: any Macro.Type] = [ + "PostgresRecord": PostgresRecordMacroType.self + ] + + let providingMacros: [any Macro.Type] = macros.map(\.value) + + init() { } +} diff --git a/Macros/PostgresRecordMacro/MacroError.swift b/Macros/PostgresRecordMacro/MacroError.swift new file mode 100644 index 0000000..5a2932f --- /dev/null +++ b/Macros/PostgresRecordMacro/MacroError.swift @@ -0,0 +1,22 @@ +import SwiftDiagnostics + +enum MacroError: Error { + case isNotStruct +} + +extension MacroError: DiagnosticMessage { + var message: String { + switch self { + case .isNotStruct: + return "Only 'struct's are supported" + } + } + + var diagnosticID: MessageID { + .init(domain: "PostgresRecordMacro.MacroError", id: self.message) + } + + var severity: DiagnosticSeverity { + .error + } +} diff --git a/Macros/PostgresRecordMacro/ParsedType.swift b/Macros/PostgresRecordMacro/ParsedType.swift new file mode 100644 index 0000000..cbdb0e4 --- /dev/null +++ b/Macros/PostgresRecordMacro/ParsedType.swift @@ -0,0 +1,80 @@ +import SwiftSyntax + +indirect enum ParsedType: CustomStringConvertible { + + enum Error: Swift.Error, CustomStringConvertible { + case unknownParameterType(String) + case failedToParse(Any.Type) + + var description: String { + switch self { + case let .unknownParameterType(type): + return "unknownParameterType(\(type))" + case let .failedToParse(type): + return "failedToParse(\(type))" + } + } + } + + case plain(String) + case optional(of: Self) + case array(of: Self) + case dictionary(key: Self, value: Self) + case member(base: Self, `extension`: String) + case unknownGeneric(String, arguments: [Self]) + + public var description: String { + switch self { + case let .plain(type): + return type + case let .optional(type): + return "\(type)?" + case let .array(type): + return "[\(type)]" + case let .dictionary(key, value): + return "[\(key): \(value)]" + case let .member(base, `extension`): + return "\(base.description).\(`extension`)" + case let .unknownGeneric(name, arguments: arguments): + return "\(name)<\(arguments.map(\.description).joined(separator: ", "))>" + } + } + + public init(syntax: some TypeSyntaxProtocol) throws { + if let type = syntax.as(IdentifierTypeSyntax.self) { + let name = type.name.trimmedDescription + if let genericArgumentClause = type.genericArgumentClause, + !genericArgumentClause.arguments.isEmpty { + let arguments = genericArgumentClause.arguments + switch (arguments.count, name) { + case (1, "Optional"): + self = try .optional(of: Self(syntax: arguments.first!.argument)) + case (1, "Array"): + self = try .array(of: Self(syntax: arguments.first!.argument)) + case (2, "Dictionary"): + let key = try Self(syntax: arguments.first!.argument) + let value = try Self(syntax: arguments.last!.argument) + self = .dictionary(key: key, value: value) + default: + let arguments = try arguments.map(\.argument).map(Self.init(syntax:)) + self = .unknownGeneric(name, arguments: arguments) + } + } else { + self = .plain(name) + } + } else if let type = syntax.as(OptionalTypeSyntax.self) { + self = try .optional(of: Self(syntax: type.wrappedType)) + } else if let type = syntax.as(ArrayTypeSyntax.self) { + self = try .array(of: Self(syntax: type.element)) + } else if let type = syntax.as(DictionaryTypeSyntax.self) { + let key = try Self(syntax: type.key) + let value = try Self(syntax: type.value) + self = .dictionary(key: key, value: value) + } else if let type = syntax.as(MemberTypeSyntax.self) { + let kind = try Self(syntax: type.baseType) + self = .member(base: kind, extension: type.name.trimmedDescription) + } else { + throw Error.unknownParameterType(syntax.trimmed.description) + } + } +} diff --git a/Macros/PostgresRecordMacro/PostgresRecordMacroType.swift b/Macros/PostgresRecordMacro/PostgresRecordMacroType.swift new file mode 100644 index 0000000..8a3add9 --- /dev/null +++ b/Macros/PostgresRecordMacro/PostgresRecordMacroType.swift @@ -0,0 +1,50 @@ +import SwiftSyntax +import SwiftDiagnostics +import SwiftSyntaxMacros + +public enum PostgresRecordMacroType: ExtensionMacro { + public static func expansion( + of node: AttributeSyntax, + attachedTo declaration: some DeclGroupSyntax, + providingExtensionsOf type: some TypeSyntaxProtocol, + conformingTo protocols: [TypeSyntax], + in context: some MacroExpansionContext + ) throws -> [ExtensionDeclSyntax] { + if declaration.hasError { return [] } + let diagnoser = Diagnoser(context: context) + + guard let structDecl = declaration.as(StructDeclSyntax.self) else { + throw MacroError.isNotStruct + } + let accessLevel = structDecl.accessLevelModifier.map { "\($0) " } ?? "" + /// Compiler won't be able to infer what function to use when doing `PostgresRow.decode()`. + let forbiddenProtocols = ["PostgresCodable", "PostgresDecodable"] + let inheritedTypes = structDecl.inheritanceClause?.inheritedTypes ?? [] + for idx in inheritedTypes.indices { + let proto = inheritedTypes[idx] + let name = proto.trimmedDescription + if forbiddenProtocols.contains(name) { + diagnoser.cannotConformToProtocol( + name: name, + old: structDecl, + new: structDecl.removingInheritedType(at: idx) + ) + return [] + } + } + + let members = structDecl.memberBlock.members + let variableDecls = members.compactMap { $0.decl.as(VariableDeclSyntax.self) } + let variables = try variableDecls.flatMap(Variable.parse(from:)) + + let name = structDecl.name.trimmedDescription + let initializer = variables.makePostgresRecordInit(name: name, accessLevel: accessLevel) + let postgresRecord = try ExtensionDeclSyntax(""" + extension \(raw: name): PostgresRecord { + \(raw: initializer) + } + """) + + return [postgresRecord] + } +} diff --git a/Macros/PostgresRecordMacro/Variable+init.swift b/Macros/PostgresRecordMacro/Variable+init.swift new file mode 100644 index 0000000..37ad1bc --- /dev/null +++ b/Macros/PostgresRecordMacro/Variable+init.swift @@ -0,0 +1,36 @@ +import SwiftSyntax + +extension [Variable] { + func makePostgresRecordInit(name: String, accessLevel: String) -> String { + """ + \(accessLevel)init( + _from row: PostgresRow, + context: PostgresDecodingContext, + file: String, + line: Int + ) throws { + let decoded = try row.decode( + \(makeType()), + context: context, + file: file, + line: line + ) + \(makeInitializations()) + } + """ + } + + private func makeType() -> String { + "(\(self.map(\.type.description).joined(separator: ","))).self" + } + + private func makeInitializations() -> String { + if self.count == 1 { + return " self.\(self[0].name) = decoded" + } else { + return self.enumerated().map { (idx, variable) in + " self.\(variable.name) = decoded.\(idx)" + }.joined(separator: "\n") + } + } +} diff --git a/Macros/PostgresRecordMacro/Variable.swift b/Macros/PostgresRecordMacro/Variable.swift new file mode 100644 index 0000000..0933931 --- /dev/null +++ b/Macros/PostgresRecordMacro/Variable.swift @@ -0,0 +1,37 @@ +import SwiftSyntax + +struct Variable { + + enum Error: Swift.Error, CustomStringConvertible { + case unsupportedPattern(String) + case typeSyntaxNotFound(name: String) + + var description: String { + switch self { + case let .unsupportedPattern(pattern): + return "unsupportedPattern(\(pattern))" + case let .typeSyntaxNotFound(name): + return "typeSyntaxNotFound(name: \(name))" + } + } + } + + let name: String + let type: ParsedType + + static func parse(from element: VariableDeclSyntax) throws -> [Variable] { + try element.bindings.map { binding in + guard let pattern = binding.pattern.as(IdentifierPatternSyntax.self) else { + throw Error.unsupportedPattern(binding.pattern.trimmedDescription) + } + let name = pattern.identifier.trimmed.text + + guard let typeSyntax = binding.typeAnnotation?.type else { + throw Error.typeSyntaxNotFound(name: name) + } + let type = try ParsedType(syntax: typeSyntax) + + return Variable(name: name, type: type) + } + } +} diff --git a/Package.swift b/Package.swift index 3133ecb..d50d71a 100644 --- a/Package.swift +++ b/Package.swift @@ -1,13 +1,14 @@ -// swift-tools-version:5.7 +// swift-tools-version:5.9 import PackageDescription +import CompilerPluginSupport let package = Package( name: "postgres-kit", platforms: [ - .macOS(.v10_15), - .iOS(.v13), - .watchOS(.v6), - .tvOS(.v13), + .macOS(.v13), + .iOS(.v16), + .watchOS(.v9), + .tvOS(.v16), ], products: [ .library(name: "PostgresKit", targets: ["PostgresKit"]), @@ -16,18 +17,36 @@ let package = Package( .package(url: "https://github.com/vapor/postgres-nio.git", from: "1.14.2"), .package(url: "https://github.com/vapor/sql-kit.git", from: "3.28.0"), .package(url: "https://github.com/vapor/async-kit.git", from: "1.14.0"), - .package(url: "https://github.com/apple/swift-atomics.git", from: "1.1.0") + .package(url: "https://github.com/apple/swift-atomics.git", from: "1.1.0"), + .package(url: "https://github.com/apple/swift-syntax", from: "509.0.0"), ], targets: [ .target(name: "PostgresKit", dependencies: [ .product(name: "AsyncKit", package: "async-kit"), .product(name: "PostgresNIO", package: "postgres-nio"), .product(name: "SQLKit", package: "sql-kit"), - .product(name: "Atomics", package: "swift-atomics") + .product(name: "Atomics", package: "swift-atomics"), + .target(name: "PostgresRecordMacro"), ]), + .macro( + name: "PostgresRecordMacro", + dependencies: [ + .product(name: "SwiftSyntaxMacros", package: "swift-syntax"), + .product(name: "SwiftCompilerPlugin", package: "swift-syntax"), + .product(name: "PostgresNIO", package: "postgres-nio"), + ], + path: "./Macros/PostgresRecordMacro" + ), .testTarget(name: "PostgresKitTests", dependencies: [ .target(name: "PostgresKit"), .product(name: "SQLKitBenchmark", package: "sql-kit"), ]), + .testTarget( + name: "MacroTests", + dependencies: [ + .product(name: "SwiftSyntaxMacrosTestSupport", package: "swift-syntax"), + .target(name: "PostgresRecordMacro"), + ] + ), ] ) diff --git a/Sources/PostgresKit/PostgresRecord.swift b/Sources/PostgresKit/PostgresRecord.swift new file mode 100644 index 0000000..09732e1 --- /dev/null +++ b/Sources/PostgresKit/PostgresRecord.swift @@ -0,0 +1,74 @@ +import PostgresRecordMacro +import PostgresNIO + +/// Enables the type to be easily and performantly decoded from the a `PostgresRow`: +/// ```swift +/// @PostgresRecord +/// struct MyTable { +/// let one: Int +/// let two: String +/// } +/// let rows: [PostgresRow] = dbManager.sql(...) +/// let items: [MyTable] = try rows.map { row in +/// try row.decode(MyTable.self) +/// } +/// ``` +/// +/// WARNING: +/// The returned postgres row data must be in the same order as declared in the Swift type. +/// So basically make sure the order of the retrieved columns is the same order as the variables of the type. +@attached( + extension, + conformances: PostgresRecord, + names: named(init) +) +public macro PostgresRecord() = #externalMacro( + module: "PostgresRecordMacro", + type: "PostgresRecordMacroType" +) + +// MARK: PostgresRecord +public protocol PostgresRecord { + init( + _from row: PostgresRow, + context: PostgresDecodingContext, + file: String, + line: Int + ) throws +} + +// MARK: +PostgresRow +extension PostgresRow { + public func decode( + _ recordType: Record.Type = Record.self, + file: String = #fileID, + line: Int = #line + ) throws -> Record { + try Record.init( + _from: self, + context: .default, + file: file, + line: line + ) + } + + public func decode( + _ recordType: Record.Type = Record.self, + context: PostgresDecodingContext, + file: String = #fileID, + line: Int = #line + ) throws -> Record { + try Record.init( + _from: self, + context: context, + file: file, + line: line + ) + } +} + +#warning("to test") +@PostgresRecord +struct MyTable { + let thing: String +} diff --git a/Tests/MacroTests/PostgresRecordMacroTests.swift b/Tests/MacroTests/PostgresRecordMacroTests.swift new file mode 100644 index 0000000..9b573f3 --- /dev/null +++ b/Tests/MacroTests/PostgresRecordMacroTests.swift @@ -0,0 +1,98 @@ +@testable import PostgresRecordMacro +import SwiftSyntaxMacros +import SwiftSyntaxMacrosTestSupport +import XCTest + +final class PostgresRecordMacroTests: XCTestCase { + + func test() throws { + assertMacroExpansion(""" + @PostgresRecord + public struct MyTable { + let int: Int + let string: String? + } + """, + expandedSource: #""" + public struct MyTable { + let int: Int + let string: String? + } + + extension MyTable: PostgresRecord { + public init( + _from row: PostgresRow, + context: PostgresDecodingContext, + file: String, + line: Int + ) throws { + let decoded = try row.decode( + (Int, String?).self, + context: context, + file: file, + line: line + ) + self.int = decoded.0 + self.string = decoded.1 + } + } + """#, + macros: PostgresRecordMacroEntryPoint.macros + ) + } + + func testOnlyAllowsStructs() { + assertMacroExpansion(""" + @PostgresRecord + enum MyTable { + case a + } + """, + expandedSource: #""" + + enum MyTable { + case a + } + """#, + diagnostics: [ + .init( + message: "Only 'struct's are supported", + line: 1, + column: 1, + severity: .error + ) + ], + macros: PostgresRecordMacroEntryPoint.macros + ) + } + + func testDoesNotAllowSimultaneousConformance() { + assertMacroExpansion(""" + @PostgresRecord + struct MyTable: PostgresDecodable { + let thing: String + } + """, + expandedSource: #""" + + struct MyTable: PostgresDecodable { + let thing: String + } + """#, + diagnostics: [ + .init( + message: "Simultaneous conformance to 'PostgresDecodable' is not supported", + line: 1, + column: 1, + severity: .error, + highlight: nil, + notes: [], + fixIts: [.init( + message: "Remove conformance to 'PostgresDecodable'" + )] + ) + ], + macros: PostgresRecordMacroEntryPoint.macros + ) + } +}