Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
MahdiBM committed Oct 29, 2023
1 parent 80ab773 commit 9e3b416
Show file tree
Hide file tree
Showing 11 changed files with 551 additions and 7 deletions.
42 changes: 42 additions & 0 deletions Macros/PostgresRecordMacro/+SwiftSyntax.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import SwiftSyntax

extension StructDeclSyntax {
func removingInheritedType(at idx: InheritedTypeListSyntax.Index) -> Self {
var new = self
new.inheritanceClause?.inheritedTypes.remove(at: idx)
/// Remove the colon after types name, in e.g. `MyTable: `, if no protocols are remaining.
if new.inheritanceClause?.inheritedTypes.isEmpty == true {
new.inheritanceClause = nil
}
return new.reformatted()
}

var accessLevelModifier: String? {
let accessLevels: [Keyword] = [.open, .public, .package, .internal, .private, .fileprivate]
for modifier in self.modifiers {
guard let modifier = modifier.as(DeclModifierSyntax.self),
case let .keyword(keyword) = modifier.name.tokenKind else {
continue
}
if accessLevels.contains(keyword) {
return modifier.name.trimmedDescription
}
}
return nil
}

/// https://github.com/apple/swift/pull/69448
/// Remove whenever this bug-fix is live (when swift 5.9.2 is out?)
func conforms(to protocolName: String) -> Bool {
self.inheritanceClause?.inheritedTypes.contains {
$0.type.as(IdentifierTypeSyntax.self)?.name.trimmedDescription == protocolName
} == true
}
}

extension SyntaxProtocol {
/// Build a syntax node from this `Buildable` and format it with the given format.
func reformatted() -> Self {
return self.formatted().as(Self.self)!
}
}
73 changes: 73 additions & 0 deletions Macros/PostgresRecordMacro/Diagnoser.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import SwiftSyntax
import SwiftSyntaxMacros
import SwiftDiagnostics

struct Diagnoser<Context: MacroExpansionContext> {
let context: Context

func cannotConformToProtocol(
name: String,
old: some SyntaxProtocol,
new: some SyntaxProtocol
) {
let diagnosis = Diagnosis.cannotConformToProtocol(name)
context.diagnose(Diagnostic(
node: old,
position: old.position,
message: diagnosis.diagnosticMessage,
highlights: nil,
notes: [],
fixIt: .replace(
message: diagnosis.fixItMessage,
oldNode: old,
newNode: new
)
))
}
}

private enum Diagnosis: Error {
case cannotConformToProtocol(String)

private struct _DiagnosticMessage: DiagnosticMessage {
let parent: Diagnosis

var message: String {
switch parent {
case let .cannotConformToProtocol(proto):
return "Simultaneous conformance to '\(proto)' is not supported"
}
}

var diagnosticID: SwiftDiagnostics.MessageID {
.init(domain: "\(Self.self)", id: self.message)
}

var severity: SwiftDiagnostics.DiagnosticSeverity {
.error
}
}

private struct _FixItMessage: FixItMessage {
let parent: Diagnosis

var message: String {
switch parent {
case let .cannotConformToProtocol(proto):
return "Remove conformance to '\(proto)'"
}
}

var fixItID: SwiftDiagnostics.MessageID {
.init(domain: "\(Self.self)", id: self.message)
}
}

var diagnosticMessage: any DiagnosticMessage {
_DiagnosticMessage(parent: self)
}

var fixItMessage: any FixItMessage {
_FixItMessage(parent: self)
}
}
13 changes: 13 additions & 0 deletions Macros/PostgresRecordMacro/EntryPoint.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import SwiftSyntaxMacros
import SwiftCompilerPlugin

@main
struct PostgresRecordMacroEntryPoint: CompilerPlugin {
static let macros: [String: any Macro.Type] = [
"PostgresRecord": PostgresRecordMacroType.self
]

let providingMacros: [any Macro.Type] = macros.map(\.value)

init() { }
}
22 changes: 22 additions & 0 deletions Macros/PostgresRecordMacro/MacroError.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import SwiftDiagnostics

enum MacroError: Error {
case isNotStruct
}

extension MacroError: DiagnosticMessage {
var message: String {
switch self {
case .isNotStruct:
return "Only 'struct's are supported"
}
}

var diagnosticID: MessageID {
.init(domain: "PostgresRecordMacro.MacroError", id: self.message)
}

var severity: DiagnosticSeverity {
.error
}
}
80 changes: 80 additions & 0 deletions Macros/PostgresRecordMacro/ParsedType.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import SwiftSyntax

indirect enum ParsedType: CustomStringConvertible {

enum Error: Swift.Error, CustomStringConvertible {
case unknownParameterType(String)
case failedToParse(Any.Type)

var description: String {
switch self {
case let .unknownParameterType(type):
return "unknownParameterType(\(type))"
case let .failedToParse(type):
return "failedToParse(\(type))"
}
}
}

case plain(String)
case optional(of: Self)
case array(of: Self)
case dictionary(key: Self, value: Self)
case member(base: Self, `extension`: String)
case unknownGeneric(String, arguments: [Self])

public var description: String {
switch self {
case let .plain(type):
return type
case let .optional(type):
return "\(type)?"
case let .array(type):
return "[\(type)]"
case let .dictionary(key, value):
return "[\(key): \(value)]"
case let .member(base, `extension`):
return "\(base.description).\(`extension`)"
case let .unknownGeneric(name, arguments: arguments):
return "\(name)<\(arguments.map(\.description).joined(separator: ", "))>"
}
}

public init(syntax: some TypeSyntaxProtocol) throws {
if let type = syntax.as(IdentifierTypeSyntax.self) {
let name = type.name.trimmedDescription
if let genericArgumentClause = type.genericArgumentClause,
!genericArgumentClause.arguments.isEmpty {
let arguments = genericArgumentClause.arguments
switch (arguments.count, name) {
case (1, "Optional"):
self = try .optional(of: Self(syntax: arguments.first!.argument))
case (1, "Array"):
self = try .array(of: Self(syntax: arguments.first!.argument))
case (2, "Dictionary"):
let key = try Self(syntax: arguments.first!.argument)
let value = try Self(syntax: arguments.last!.argument)
self = .dictionary(key: key, value: value)
default:
let arguments = try arguments.map(\.argument).map(Self.init(syntax:))
self = .unknownGeneric(name, arguments: arguments)
}
} else {
self = .plain(name)
}
} else if let type = syntax.as(OptionalTypeSyntax.self) {
self = try .optional(of: Self(syntax: type.wrappedType))
} else if let type = syntax.as(ArrayTypeSyntax.self) {
self = try .array(of: Self(syntax: type.element))
} else if let type = syntax.as(DictionaryTypeSyntax.self) {
let key = try Self(syntax: type.key)
let value = try Self(syntax: type.value)
self = .dictionary(key: key, value: value)
} else if let type = syntax.as(MemberTypeSyntax.self) {
let kind = try Self(syntax: type.baseType)
self = .member(base: kind, extension: type.name.trimmedDescription)
} else {
throw Error.unknownParameterType(syntax.trimmed.description)
}
}
}
50 changes: 50 additions & 0 deletions Macros/PostgresRecordMacro/PostgresRecordMacroType.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import SwiftSyntax
import SwiftDiagnostics
import SwiftSyntaxMacros

public enum PostgresRecordMacroType: ExtensionMacro {
public static func expansion(
of node: AttributeSyntax,
attachedTo declaration: some DeclGroupSyntax,
providingExtensionsOf type: some TypeSyntaxProtocol,
conformingTo protocols: [TypeSyntax],
in context: some MacroExpansionContext
) throws -> [ExtensionDeclSyntax] {
if declaration.hasError { return [] }
let diagnoser = Diagnoser(context: context)

guard let structDecl = declaration.as(StructDeclSyntax.self) else {
throw MacroError.isNotStruct
}
let accessLevel = structDecl.accessLevelModifier.map { "\($0) " } ?? ""
/// Compiler won't be able to infer what function to use when doing `PostgresRow.decode()`.
let forbiddenProtocols = ["PostgresCodable", "PostgresDecodable"]
let inheritedTypes = structDecl.inheritanceClause?.inheritedTypes ?? []
for idx in inheritedTypes.indices {
let proto = inheritedTypes[idx]
let name = proto.trimmedDescription
if forbiddenProtocols.contains(name) {
diagnoser.cannotConformToProtocol(
name: name,
old: structDecl,
new: structDecl.removingInheritedType(at: idx)
)
return []
}
}

let members = structDecl.memberBlock.members
let variableDecls = members.compactMap { $0.decl.as(VariableDeclSyntax.self) }
let variables = try variableDecls.flatMap(Variable.parse(from:))

let name = structDecl.name.trimmedDescription
let initializer = variables.makePostgresRecordInit(name: name, accessLevel: accessLevel)
let postgresRecord = try ExtensionDeclSyntax("""
extension \(raw: name): PostgresRecord {
\(raw: initializer)
}
""")

return [postgresRecord]
}
}
36 changes: 36 additions & 0 deletions Macros/PostgresRecordMacro/Variable+init.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import SwiftSyntax

extension [Variable] {
func makePostgresRecordInit(name: String, accessLevel: String) -> String {
"""
\(accessLevel)init(
_from row: PostgresRow,
context: PostgresDecodingContext<some PostgresJSONDecoder>,
file: String,
line: Int
) throws {
let decoded = try row.decode(
\(makeType()),
context: context,
file: file,
line: line
)
\(makeInitializations())
}
"""
}

private func makeType() -> String {
"(\(self.map(\.type.description).joined(separator: ","))).self"
}

private func makeInitializations() -> String {
if self.count == 1 {
return " self.\(self[0].name) = decoded"
} else {
return self.enumerated().map { (idx, variable) in
" self.\(variable.name) = decoded.\(idx)"
}.joined(separator: "\n")
}
}
}
37 changes: 37 additions & 0 deletions Macros/PostgresRecordMacro/Variable.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import SwiftSyntax

struct Variable {

enum Error: Swift.Error, CustomStringConvertible {
case unsupportedPattern(String)
case typeSyntaxNotFound(name: String)

var description: String {
switch self {
case let .unsupportedPattern(pattern):
return "unsupportedPattern(\(pattern))"
case let .typeSyntaxNotFound(name):
return "typeSyntaxNotFound(name: \(name))"
}
}
}

let name: String
let type: ParsedType

static func parse(from element: VariableDeclSyntax) throws -> [Variable] {
try element.bindings.map { binding in
guard let pattern = binding.pattern.as(IdentifierPatternSyntax.self) else {
throw Error.unsupportedPattern(binding.pattern.trimmedDescription)
}
let name = pattern.identifier.trimmed.text

guard let typeSyntax = binding.typeAnnotation?.type else {
throw Error.typeSyntaxNotFound(name: name)
}
let type = try ParsedType(syntax: typeSyntax)

return Variable(name: name, type: type)
}
}
}
Loading

0 comments on commit 9e3b416

Please sign in to comment.