Skip to content

Commit

Permalink
Merge pull request #61 from squareup/skorulis/getter-named
Browse files Browse the repository at this point in the history
Add parameter to name-getters to allow customisation
  • Loading branch information
skorulis-ap authored Aug 1, 2023
2 parents 806aa6e + 05ffe7e commit d2a24e9
Show file tree
Hide file tree
Showing 10 changed files with 204 additions and 60 deletions.
4 changes: 2 additions & 2 deletions Example/KnitExample/KnitExampleAssembly.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ final class KnitExampleAssembly: Assembly {
func assemble(container: Container) {
container.addBehavior(ServiceCollector())

// @knit named-getter
// @knit getter-named
container.autoregister(ExampleService.self, initializer: ExampleService.init)

// @knit named-getter
// @knit getter-named("example")
container.register(ExampleArgumentService.self) { (_, arg: String) in
ExampleArgumentService.init(string: arg)
}
Expand Down
11 changes: 10 additions & 1 deletion Sources/KnitCodeGen/AssemblyParsing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func parseSyntaxTree(
throw AssemblyParsingError.missingModuleName
}

errorsToPrint.append(contentsOf: assemblyFileVisitor.assemblyErrors)
errorsToPrint.append(contentsOf: assemblyFileVisitor.registrationErrors)

return Configuration(
Expand All @@ -51,6 +52,8 @@ private class AssemblyFileVisitor: SyntaxVisitor {

private var classDeclVisitor: ClassDeclVisitor?

private(set) var assemblyErrors: [Error] = []

var registrations: [Registration] {
return classDeclVisitor?.registrations ?? []
}
Expand Down Expand Up @@ -85,7 +88,13 @@ private class AssemblyFileVisitor: SyntaxVisitor {
// Only the first class declaration should be visited
return .skipChildren
}
let directives = KnitDirectives.parse(leadingTrivia: node.leadingTrivia)
var directives: KnitDirectives = .empty
do {
directives = try KnitDirectives.parse(leadingTrivia: node.leadingTrivia)
} catch {
assemblyErrors.append(error)
}

moduleName = node.moduleNameForAssembly
classDeclVisitor = ClassDeclVisitor(viewMode: .fixedUp, directives: directives)
classDeclVisitor?.walk(node)
Expand Down
11 changes: 9 additions & 2 deletions Sources/KnitCodeGen/FunctionCallRegistrationParsing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,22 @@ private func makeRegistrationFor(

let registrationText = firstParam.base!.withoutTrivia().description
let name = try getName(arguments: arguments)
let directives = KnitDirectives.parse(leadingTrivia: leadingTrivia)
let directives = try KnitDirectives.parse(leadingTrivia: leadingTrivia)

var getterConfig: Set<GetterConfig> = GetterConfig.default
if !directives.getterConfig.isEmpty {
getterConfig = directives.getterConfig
} else if !defaultDirectives.getterConfig.isEmpty {
getterConfig = defaultDirectives.getterConfig
}

return Registration(
service: registrationText,
name: name,
accessLevel: directives.accessLevel ?? defaultDirectives.accessLevel ?? .default,
arguments: registrationArguments,
isForwarded: isForwarded,
getterConfig: directives.getterConfig ?? defaultDirectives.getterConfig ?? .default
getterConfig: getterConfig
)
}

Expand Down
94 changes: 70 additions & 24 deletions Sources/KnitCodeGen/KnitDirectives.swift
Original file line number Diff line number Diff line change
@@ -1,51 +1,97 @@
// Copyright © Square, Inc. All rights reserved.

import Foundation
import SwiftSyntax

struct KnitDirectives: Codable {
struct KnitDirectives: Codable, Equatable {
let accessLevel: AccessLevel?
let getterConfig: GetterConfig?
let getterConfig: Set<GetterConfig>

static func parse(leadingTrivia: Trivia?) -> KnitDirectives {
guard let leadingTriviaText = leadingTrivia?.description, leadingTriviaText.contains("@knit") else {
static func parse(leadingTrivia: Trivia?) throws -> KnitDirectives {
guard let leadingTriviaText = leadingTrivia?.description else {
return .empty
}
let accessLevel: AccessLevel? = AccessLevel.allCases.first { leadingTriviaText.contains($0.rawValue) }
var tokens = leadingTriviaText
.components(separatedBy: .whitespacesAndNewlines)
.filter { !$0.isEmpty && $0 != "//" }
guard tokens.first == "@knit" else {
return .empty
}
tokens = Array(tokens.dropFirst())

let identifiedGetterOnly = leadingTriviaText.contains("getter-named")
let callAsFuncOnly = leadingTriviaText.contains("getter-callAsFunction")
var accessLevel: AccessLevel?
var getterConfigs: Set<GetterConfig> = []

let getterConfig: GetterConfig?
switch (identifiedGetterOnly, callAsFuncOnly) {
case (false, false):
getterConfig = nil
case (true, false):
getterConfig = .identifiedGetter
case (false, true):
getterConfig = .callAsFunction
case (true, true):
getterConfig = .both
for token in tokens {
let parsed = try parse(token: token)
if let level = parsed.accessLevel {
accessLevel = level
}
if let getter = parsed.getterConfig {
getterConfigs.insert(getter)
}
}

return KnitDirectives(accessLevel: accessLevel, getterConfig: getterConfig)
return KnitDirectives(accessLevel: accessLevel, getterConfig: getterConfigs)
}

static func parse(token: String) throws -> (accessLevel: AccessLevel?, getterConfig: GetterConfig?) {
if let accessLevel = AccessLevel(rawValue: token) {
return (accessLevel, nil)
}
if token == "getter-callAsFunction" {
return (nil, .callAsFunction)
}
if let nameMatch = getterNamedRegex.firstMatch(in: token, range: NSMakeRange(0, token.count)) {
if nameMatch.numberOfRanges >= 2, nameMatch.range(at: 1).location != NSNotFound {
var range = nameMatch.range(at: 1)
range = NSRange(location: range.location + 2, length: range.length - 4)
let name = (token as NSString).substring(with: range)
return (nil, .identifiedGetter(name))
}
return (nil, .identifiedGetter(nil))
}

throw Error.unexpectedToken(token: token)
}

static var empty: KnitDirectives {
return .init(accessLevel: nil, getterConfig: nil)
return .init(accessLevel: nil, getterConfig: [])
}

private static let getterNamedRegex = try! NSRegularExpression(pattern: "getter-named(\\(\"\\w*\"\\))?")
}

extension KnitDirectives {
enum Error: LocalizedError {
case unexpectedToken(token: String)

public enum GetterConfig: Codable, CaseIterable {
var errorDescription: String? {
switch self {
case let .unexpectedToken(token):
return "Unexpected knit comment rule \(token)"
}
}
}
}

public enum GetterConfig: Codable, Equatable, Hashable {
/// Only the `callAsFunction()` accessor is generated.
case callAsFunction
/// Only the identified getter is generated.
case identifiedGetter
/// Both the identified getter and the `callAsFunction()` accessors are generated.
case both
case identifiedGetter(_ name: String?)

/// Centralized control of the default behavior.
public static var `default`: GetterConfig = .callAsFunction
public static var `default`: Set<GetterConfig> = [.callAsFunction]

public static var both: Set<GetterConfig> = [.callAsFunction, .identifiedGetter(nil)]

public var isNamed: Bool {
switch self {
case .identifiedGetter: return true
default: return false
}
}
}

public enum AccessLevel: String, CaseIterable, Codable {
Expand Down
4 changes: 2 additions & 2 deletions Sources/KnitCodeGen/Registration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ public struct Registration: Equatable, Codable {
public var isForwarded: Bool

/// This registration's getter setting.
public var getterConfig: GetterConfig
public var getterConfig: Set<GetterConfig>

public init(
service: String,
name: String? = nil,
accessLevel: AccessLevel,
arguments: [Argument] = [],
isForwarded: Bool = false,
getterConfig: GetterConfig = .default
getterConfig: Set<GetterConfig> = GetterConfig.default
) {
self.service = service
self.name = name
Expand Down
34 changes: 18 additions & 16 deletions Sources/KnitCodeGen/TypeSafetySourceFile.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,7 @@ public enum TypeSafetySourceFile {
// Exclude hidden registrations always
$0.accessLevel != .hidden
}
let identifiedGetterRegistrations = visibleRegistrations.filter {
$0.getterConfig != .callAsFunction && $0.name == nil
}
let callAsFunctionRegistrations = visibleRegistrations.filter {
// Don't include generated `callAsFunction` when only generating identified getter
$0.getterConfig != .identifiedGetter && $0.name == nil
}
let unnamedRegistrations = visibleRegistrations.filter { $0.name == nil }
let namedGroups = NamedRegistrationGroup.make(from: visibleRegistrations)
return SourceFileSyntax(leadingTrivia: TriviaProvider.headerTrivia) {
for importItem in imports {
Expand All @@ -33,18 +27,20 @@ public enum TypeSafetySourceFile {
extension \(extensionTarget)
""") {

for registration in identifiedGetterRegistrations {
makeResolver(registration: registration, identifiedGetter: true)
}
for registration in callAsFunctionRegistrations {
makeResolver(registration: registration)
for registration in unnamedRegistrations {
if registration.getterConfig.contains(.callAsFunction) {
makeResolver(registration: registration, getterType: .callAsFunction)
}
if let namedGetter = registration.getterConfig.first(where: { $0.isNamed }) {
makeResolver(registration: registration, getterType: namedGetter)
}
}
for namedGroup in namedGroups {
let firstGetterConfig = namedGroup.registrations[0].getterConfig
let firstGetterConfig = namedGroup.registrations[0].getterConfig.first ?? .callAsFunction
makeResolver(
registration: namedGroup.registrations[0],
enumName: "\(assemblyName).\(namedGroup.enumName)",
identifiedGetter: firstGetterConfig != .callAsFunction
getterType: firstGetterConfig
)
}
}
Expand All @@ -58,15 +54,21 @@ public enum TypeSafetySourceFile {
static func makeResolver(
registration: Registration,
enumName: String? = nil,
identifiedGetter: Bool = false
getterType: GetterConfig = .callAsFunction
) -> FunctionDeclSyntax {
let modifier = registration.accessLevel == .public ? "public " : ""
let nameInput = enumName.map { "name: \($0)" }
let nameUsage = enumName != nil ? "name: name.rawValue" : nil
let (argInput, argUsage) = argumentString(registration: registration)
let inputs = [nameInput, argInput].compactMap { $0 }.joined(separator: ", ")
let usages = ["\(registration.service).self", nameUsage, argUsage].compactMap { $0 }.joined(separator: ", ")
let funcName = identifiedGetter ? TypeNamer.computedIdentifierName(type: registration.service) : "callAsFunction"
let funcName: String
switch getterType {
case .callAsFunction:
funcName = "callAsFunction"
case let .identifiedGetter(name):
funcName = name ?? TypeNamer.computedIdentifierName(type: registration.service)
}

return FunctionDeclSyntax("\(modifier)func \(funcName)(\(inputs)) -> \(registration.service)") {
ForcedValueExprSyntax("self.resolve(\(raw: usages))!")
Expand Down
4 changes: 2 additions & 2 deletions Tests/KnitCodeGenTests/AssemblyParsingTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ final class AssemblyParsingTests: XCTestCase {
XCTAssertEqual(
config.registrations,
[
.init(service: "A", accessLevel: .public, getterConfig: .identifiedGetter),
.init(service: "B", accessLevel: .internal, getterConfig: .callAsFunction)
.init(service: "A", accessLevel: .public, getterConfig: [.identifiedGetter(nil)]),
.init(service: "B", accessLevel: .internal, getterConfig: [.callAsFunction])
]
)
}
Expand Down
80 changes: 80 additions & 0 deletions Tests/KnitCodeGenTests/KnitDirectivesTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Created by Alexander skorulis on 28/7/2023.

@testable import KnitCodeGen
import SwiftSyntax
import XCTest

final class KnitDirectivesTests: XCTestCase {

func testAccessLevel() throws {
XCTAssertEqual(
try parse(" @knit public"),
.init(accessLevel: .public, getterConfig: [])
)

XCTAssertEqual(
try parse("@knit internal"),
.init(accessLevel: .internal, getterConfig: [])
)

XCTAssertEqual(
try parse("@knit hidden"),
.init(accessLevel: .hidden, getterConfig: [])
)
}

func testKnitPrefix() {
XCTAssertEqual(
try parse("// @knit public"),
.init(accessLevel: .public, getterConfig: [])
)

XCTAssertEqual(
try parse("knit public"),
.empty
)

XCTAssertEqual(
try parse("public @knit"),
.empty
)

XCTAssertEqual(
try parse("informational comment"),
.empty
)
}

func testGetterConfig() {
XCTAssertEqual(
try parse("// @knit getter-named"),
.init(accessLevel: nil, getterConfig: [.identifiedGetter(nil)])
)

XCTAssertEqual(
try parse("// @knit getter-named(\"customName\")"),
.init(accessLevel: nil, getterConfig: [.identifiedGetter("customName")])
)

XCTAssertEqual(
try parse("// @knit getter-callAsFunction"),
.init(accessLevel: nil, getterConfig: [.callAsFunction])
)

XCTAssertEqual(
try parse("// @knit getter-callAsFunction getter-named"),
.init(accessLevel: nil, getterConfig: [.identifiedGetter(nil), .callAsFunction])
)

XCTAssertEqual(
try parse("// @knit getter-callAsFunction getter-named"),
.init(accessLevel: nil, getterConfig: [.identifiedGetter(nil), .callAsFunction])
)
}

private func parse(_ comment: String) throws -> KnitDirectives {
let trivia = Trivia(pieces: [.lineComment(comment)])
return try KnitDirectives.parse(leadingTrivia: trivia)
}

}
6 changes: 3 additions & 3 deletions Tests/KnitCodeGenTests/RegistrationParsingTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ final class RegistrationParsingTests: XCTestCase {
container.register(A.self) { }
""",
registrations: [
.init(service: "A", accessLevel: .public, getterConfig: .identifiedGetter)
.init(service: "A", accessLevel: .public, getterConfig: [.identifiedGetter(nil)])
]
)

Expand All @@ -103,7 +103,7 @@ final class RegistrationParsingTests: XCTestCase {
container.register(A.self) { }
""",
registrations: [
.init(service: "A", accessLevel: .public, getterConfig: .callAsFunction)
.init(service: "A", accessLevel: .public, getterConfig: [.callAsFunction])
]
)

Expand All @@ -113,7 +113,7 @@ final class RegistrationParsingTests: XCTestCase {
container.register(A.self) { }
""",
registrations: [
.init(service: "A", accessLevel: .public, getterConfig: .both)
.init(service: "A", accessLevel: .public, getterConfig: GetterConfig.both)
]
)
}
Expand Down
Loading

0 comments on commit d2a24e9

Please sign in to comment.