-
-
Notifications
You must be signed in to change notification settings - Fork 71
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
551 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import SwiftSyntax | ||
|
||
extension StructDeclSyntax { | ||
func removingInheritedType(at idx: InheritedTypeListSyntax.Index) -> Self { | ||
var new = self | ||
new.inheritanceClause?.inheritedTypes.remove(at: idx) | ||
/// Remove the colon after types name, in e.g. `MyTable: `, if no protocols are remaining. | ||
if new.inheritanceClause?.inheritedTypes.isEmpty == true { | ||
new.inheritanceClause = nil | ||
} | ||
return new.reformatted() | ||
} | ||
|
||
var accessLevelModifier: String? { | ||
let accessLevels: [Keyword] = [.open, .public, .package, .internal, .private, .fileprivate] | ||
for modifier in self.modifiers { | ||
guard let modifier = modifier.as(DeclModifierSyntax.self), | ||
case let .keyword(keyword) = modifier.name.tokenKind else { | ||
continue | ||
} | ||
if accessLevels.contains(keyword) { | ||
return modifier.name.trimmedDescription | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
/// https://github.com/apple/swift/pull/69448 | ||
/// Remove whenever this bug-fix is live (when swift 5.9.2 is out?) | ||
func conforms(to protocolName: String) -> Bool { | ||
self.inheritanceClause?.inheritedTypes.contains { | ||
$0.type.as(IdentifierTypeSyntax.self)?.name.trimmedDescription == protocolName | ||
} == true | ||
} | ||
} | ||
|
||
extension SyntaxProtocol { | ||
/// Build a syntax node from this `Buildable` and format it with the given format. | ||
func reformatted() -> Self { | ||
return self.formatted().as(Self.self)! | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import SwiftSyntax | ||
import SwiftSyntaxMacros | ||
import SwiftDiagnostics | ||
|
||
struct Diagnoser<Context: MacroExpansionContext> { | ||
let context: Context | ||
|
||
func cannotConformToProtocol( | ||
name: String, | ||
old: some SyntaxProtocol, | ||
new: some SyntaxProtocol | ||
) { | ||
let diagnosis = Diagnosis.cannotConformToProtocol(name) | ||
context.diagnose(Diagnostic( | ||
node: old, | ||
position: old.position, | ||
message: diagnosis.diagnosticMessage, | ||
highlights: nil, | ||
notes: [], | ||
fixIt: .replace( | ||
message: diagnosis.fixItMessage, | ||
oldNode: old, | ||
newNode: new | ||
) | ||
)) | ||
} | ||
} | ||
|
||
private enum Diagnosis: Error { | ||
case cannotConformToProtocol(String) | ||
|
||
private struct _DiagnosticMessage: DiagnosticMessage { | ||
let parent: Diagnosis | ||
|
||
var message: String { | ||
switch parent { | ||
case let .cannotConformToProtocol(proto): | ||
return "Simultaneous conformance to '\(proto)' is not supported" | ||
} | ||
} | ||
|
||
var diagnosticID: SwiftDiagnostics.MessageID { | ||
.init(domain: "\(Self.self)", id: self.message) | ||
} | ||
|
||
var severity: SwiftDiagnostics.DiagnosticSeverity { | ||
.error | ||
} | ||
} | ||
|
||
private struct _FixItMessage: FixItMessage { | ||
let parent: Diagnosis | ||
|
||
var message: String { | ||
switch parent { | ||
case let .cannotConformToProtocol(proto): | ||
return "Remove conformance to '\(proto)'" | ||
} | ||
} | ||
|
||
var fixItID: SwiftDiagnostics.MessageID { | ||
.init(domain: "\(Self.self)", id: self.message) | ||
} | ||
} | ||
|
||
var diagnosticMessage: any DiagnosticMessage { | ||
_DiagnosticMessage(parent: self) | ||
} | ||
|
||
var fixItMessage: any FixItMessage { | ||
_FixItMessage(parent: self) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import SwiftSyntaxMacros | ||
import SwiftCompilerPlugin | ||
|
||
@main | ||
struct PostgresRecordMacroEntryPoint: CompilerPlugin { | ||
static let macros: [String: any Macro.Type] = [ | ||
"PostgresRecord": PostgresRecordMacroType.self | ||
] | ||
|
||
let providingMacros: [any Macro.Type] = macros.map(\.value) | ||
|
||
init() { } | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import SwiftDiagnostics | ||
|
||
enum MacroError: Error { | ||
case isNotStruct | ||
} | ||
|
||
extension MacroError: DiagnosticMessage { | ||
var message: String { | ||
switch self { | ||
case .isNotStruct: | ||
return "Only 'struct's are supported" | ||
} | ||
} | ||
|
||
var diagnosticID: MessageID { | ||
.init(domain: "PostgresRecordMacro.MacroError", id: self.message) | ||
} | ||
|
||
var severity: DiagnosticSeverity { | ||
.error | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import SwiftSyntax | ||
|
||
indirect enum ParsedType: CustomStringConvertible { | ||
|
||
enum Error: Swift.Error, CustomStringConvertible { | ||
case unknownParameterType(String) | ||
case failedToParse(Any.Type) | ||
|
||
var description: String { | ||
switch self { | ||
case let .unknownParameterType(type): | ||
return "unknownParameterType(\(type))" | ||
case let .failedToParse(type): | ||
return "failedToParse(\(type))" | ||
} | ||
} | ||
} | ||
|
||
case plain(String) | ||
case optional(of: Self) | ||
case array(of: Self) | ||
case dictionary(key: Self, value: Self) | ||
case member(base: Self, `extension`: String) | ||
case unknownGeneric(String, arguments: [Self]) | ||
|
||
public var description: String { | ||
switch self { | ||
case let .plain(type): | ||
return type | ||
case let .optional(type): | ||
return "\(type)?" | ||
case let .array(type): | ||
return "[\(type)]" | ||
case let .dictionary(key, value): | ||
return "[\(key): \(value)]" | ||
case let .member(base, `extension`): | ||
return "\(base.description).\(`extension`)" | ||
case let .unknownGeneric(name, arguments: arguments): | ||
return "\(name)<\(arguments.map(\.description).joined(separator: ", "))>" | ||
} | ||
} | ||
|
||
public init(syntax: some TypeSyntaxProtocol) throws { | ||
if let type = syntax.as(IdentifierTypeSyntax.self) { | ||
let name = type.name.trimmedDescription | ||
if let genericArgumentClause = type.genericArgumentClause, | ||
!genericArgumentClause.arguments.isEmpty { | ||
let arguments = genericArgumentClause.arguments | ||
switch (arguments.count, name) { | ||
case (1, "Optional"): | ||
self = try .optional(of: Self(syntax: arguments.first!.argument)) | ||
case (1, "Array"): | ||
self = try .array(of: Self(syntax: arguments.first!.argument)) | ||
case (2, "Dictionary"): | ||
let key = try Self(syntax: arguments.first!.argument) | ||
let value = try Self(syntax: arguments.last!.argument) | ||
self = .dictionary(key: key, value: value) | ||
default: | ||
let arguments = try arguments.map(\.argument).map(Self.init(syntax:)) | ||
self = .unknownGeneric(name, arguments: arguments) | ||
} | ||
} else { | ||
self = .plain(name) | ||
} | ||
} else if let type = syntax.as(OptionalTypeSyntax.self) { | ||
self = try .optional(of: Self(syntax: type.wrappedType)) | ||
} else if let type = syntax.as(ArrayTypeSyntax.self) { | ||
self = try .array(of: Self(syntax: type.element)) | ||
} else if let type = syntax.as(DictionaryTypeSyntax.self) { | ||
let key = try Self(syntax: type.key) | ||
let value = try Self(syntax: type.value) | ||
self = .dictionary(key: key, value: value) | ||
} else if let type = syntax.as(MemberTypeSyntax.self) { | ||
let kind = try Self(syntax: type.baseType) | ||
self = .member(base: kind, extension: type.name.trimmedDescription) | ||
} else { | ||
throw Error.unknownParameterType(syntax.trimmed.description) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import SwiftSyntax | ||
import SwiftDiagnostics | ||
import SwiftSyntaxMacros | ||
|
||
public enum PostgresRecordMacroType: ExtensionMacro { | ||
public static func expansion( | ||
of node: AttributeSyntax, | ||
attachedTo declaration: some DeclGroupSyntax, | ||
providingExtensionsOf type: some TypeSyntaxProtocol, | ||
conformingTo protocols: [TypeSyntax], | ||
in context: some MacroExpansionContext | ||
) throws -> [ExtensionDeclSyntax] { | ||
if declaration.hasError { return [] } | ||
let diagnoser = Diagnoser(context: context) | ||
|
||
guard let structDecl = declaration.as(StructDeclSyntax.self) else { | ||
throw MacroError.isNotStruct | ||
} | ||
let accessLevel = structDecl.accessLevelModifier.map { "\($0) " } ?? "" | ||
/// Compiler won't be able to infer what function to use when doing `PostgresRow.decode()`. | ||
let forbiddenProtocols = ["PostgresCodable", "PostgresDecodable"] | ||
let inheritedTypes = structDecl.inheritanceClause?.inheritedTypes ?? [] | ||
for idx in inheritedTypes.indices { | ||
let proto = inheritedTypes[idx] | ||
let name = proto.trimmedDescription | ||
if forbiddenProtocols.contains(name) { | ||
diagnoser.cannotConformToProtocol( | ||
name: name, | ||
old: structDecl, | ||
new: structDecl.removingInheritedType(at: idx) | ||
) | ||
return [] | ||
} | ||
} | ||
|
||
let members = structDecl.memberBlock.members | ||
let variableDecls = members.compactMap { $0.decl.as(VariableDeclSyntax.self) } | ||
let variables = try variableDecls.flatMap(Variable.parse(from:)) | ||
|
||
let name = structDecl.name.trimmedDescription | ||
let initializer = variables.makePostgresRecordInit(name: name, accessLevel: accessLevel) | ||
let postgresRecord = try ExtensionDeclSyntax(""" | ||
extension \(raw: name): PostgresRecord { | ||
\(raw: initializer) | ||
} | ||
""") | ||
|
||
return [postgresRecord] | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import SwiftSyntax | ||
|
||
extension [Variable] { | ||
func makePostgresRecordInit(name: String, accessLevel: String) -> String { | ||
""" | ||
\(accessLevel)init( | ||
_from row: PostgresRow, | ||
context: PostgresDecodingContext<some PostgresJSONDecoder>, | ||
file: String, | ||
line: Int | ||
) throws { | ||
let decoded = try row.decode( | ||
\(makeType()), | ||
context: context, | ||
file: file, | ||
line: line | ||
) | ||
\(makeInitializations()) | ||
} | ||
""" | ||
} | ||
|
||
private func makeType() -> String { | ||
"(\(self.map(\.type.description).joined(separator: ","))).self" | ||
} | ||
|
||
private func makeInitializations() -> String { | ||
if self.count == 1 { | ||
return " self.\(self[0].name) = decoded" | ||
} else { | ||
return self.enumerated().map { (idx, variable) in | ||
" self.\(variable.name) = decoded.\(idx)" | ||
}.joined(separator: "\n") | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import SwiftSyntax | ||
|
||
struct Variable { | ||
|
||
enum Error: Swift.Error, CustomStringConvertible { | ||
case unsupportedPattern(String) | ||
case typeSyntaxNotFound(name: String) | ||
|
||
var description: String { | ||
switch self { | ||
case let .unsupportedPattern(pattern): | ||
return "unsupportedPattern(\(pattern))" | ||
case let .typeSyntaxNotFound(name): | ||
return "typeSyntaxNotFound(name: \(name))" | ||
} | ||
} | ||
} | ||
|
||
let name: String | ||
let type: ParsedType | ||
|
||
static func parse(from element: VariableDeclSyntax) throws -> [Variable] { | ||
try element.bindings.map { binding in | ||
guard let pattern = binding.pattern.as(IdentifierPatternSyntax.self) else { | ||
throw Error.unsupportedPattern(binding.pattern.trimmedDescription) | ||
} | ||
let name = pattern.identifier.trimmed.text | ||
|
||
guard let typeSyntax = binding.typeAnnotation?.type else { | ||
throw Error.typeSyntaxNotFound(name: name) | ||
} | ||
let type = try ParsedType(syntax: typeSyntax) | ||
|
||
return Variable(name: name, type: type) | ||
} | ||
} | ||
} |
Oops, something went wrong.