Skip to content

Commit

Permalink
add identity case for QueryPredicateGroup
Browse files Browse the repository at this point in the history
  • Loading branch information
5d committed Mar 11, 2024
1 parent 93518d7 commit a63d07f
Show file tree
Hide file tree
Showing 10 changed files with 67 additions and 53 deletions.
16 changes: 0 additions & 16 deletions Amplify/Categories/DataStore/Query/ModelKey.swift
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,6 @@ extension CodingKey where Self: ModelKey {
return field(stringValue).eq(value)
}

public func eq(_ value: Persistable) -> QueryPredicateOperation {
return field(stringValue).eq(value)
}

public func eq(_ value: EnumPersistable) -> QueryPredicateOperation {
return field(stringValue).eq(value)
}
Expand All @@ -84,10 +80,6 @@ extension CodingKey where Self: ModelKey {
return key.eq(value)
}

public static func == (key: Self, value: Persistable) -> QueryPredicateOperation {
return key.eq(value)
}

public static func == (key: Self, value: EnumPersistable) -> QueryPredicateOperation {
return key.eq(value)
}
Expand Down Expand Up @@ -138,10 +130,6 @@ extension CodingKey where Self: ModelKey {
return field(stringValue).ne(value)
}

public func ne(_ value: Persistable) -> QueryPredicateOperation {
return field(stringValue).ne(value)
}

public func ne(_ value: EnumPersistable) -> QueryPredicateOperation {
return field(stringValue).ne(value)
}
Expand All @@ -150,10 +138,6 @@ extension CodingKey where Self: ModelKey {
return key.ne(value)
}

public static func != (key: Self, value: Persistable) -> QueryPredicateOperation {
return key.ne(value)
}

public static func != (key: Self, value: EnumPersistable) -> QueryPredicateOperation {
return key.ne(value)
}
Expand Down
41 changes: 16 additions & 25 deletions Amplify/Categories/DataStore/Query/QueryField.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,24 @@ public protocol QueryFieldOperation {
func between(start: Persistable, end: Persistable) -> QueryPredicateOperation
func contains(_ value: String) -> QueryPredicateOperation
func notContains(_ value: String) -> QueryPredicateOperation
func eq(_ value: Persistable) -> QueryPredicateOperation
func eq(_ value: Persistable?) -> QueryPredicateGroup
func eq(_ value: EnumPersistable) -> QueryPredicateOperation
func ge(_ value: Persistable) -> QueryPredicateOperation
func gt(_ value: Persistable) -> QueryPredicateOperation
func le(_ value: Persistable) -> QueryPredicateOperation
func lt(_ value: Persistable) -> QueryPredicateOperation
func ne(_ value: Persistable) -> QueryPredicateOperation
func ne(_ value: Persistable?) -> QueryPredicateGroup
func ne(_ value: EnumPersistable) -> QueryPredicateOperation

// MARK: - Operators

static func ~= (key: Self, value: String) -> QueryPredicateOperation
static func == (key: Self, value: Persistable) -> QueryPredicateOperation
static func == (key: Self, value: Persistable?) -> QueryPredicateGroup
static func == (key: Self, value: EnumPersistable) -> QueryPredicateOperation
static func >= (key: Self, value: Persistable) -> QueryPredicateOperation
static func > (key: Self, value: Persistable) -> QueryPredicateOperation
static func <= (key: Self, value: Persistable) -> QueryPredicateOperation
static func < (key: Self, value: Persistable) -> QueryPredicateOperation
static func != (key: Self, value: Persistable) -> QueryPredicateOperation
static func != (key: Self, value: Persistable?) -> QueryPredicateGroup
static func != (key: Self, value: EnumPersistable) -> QueryPredicateOperation
}
Expand Down Expand Up @@ -96,13 +92,15 @@ public struct QueryField: QueryFieldOperation {
}

// MARK: - eq
public func eq(_ value: Persistable) -> QueryPredicateOperation {
return QueryPredicateOperation(field: name, operator: .equals(value))
}

public func eq(_ value: Persistable?) -> QueryPredicateGroup {
return QueryPredicateOperation(field: name, operator: .attributeExists(false))
|| QueryPredicateOperation(field: name, operator: .equals(value))
if let value {
return QueryPredicateGroup(
predicate: QueryPredicateOperation(field: name, operator: .equals(value))
)
} else {
return QueryPredicateOperation(field: name, operator: .attributeExists(false))
|| QueryPredicateOperation(field: name, operator: .equals(value))
}
}

public func eq(_ value: EnumPersistable) -> QueryPredicateOperation {
Expand All @@ -113,10 +111,6 @@ public struct QueryField: QueryFieldOperation {
return key.eq(value)
}

public static func == (key: Self, value: Persistable) -> QueryPredicateOperation {
return key.eq(value)
}

public static func == (key: Self, value: EnumPersistable) -> QueryPredicateOperation {
return key.eq(value)
}
Expand Down Expand Up @@ -164,15 +158,16 @@ public struct QueryField: QueryFieldOperation {
// MARK: - ne

public func ne(_ value: Persistable?) -> QueryPredicateGroup {
return QueryPredicateOperation(field: name, operator: .attributeExists(true))
&& QueryPredicateOperation(field: name, operator: .notEqual(value))
if let value {
return QueryPredicateGroup(
predicate: QueryPredicateOperation(field: name, operator: .notEqual(value))
)
} else {
return QueryPredicateOperation(field: name, operator: .attributeExists(true))
&& QueryPredicateOperation(field: name, operator: .notEqual(value))
}
}

public func ne(_ value: Persistable) -> QueryPredicateOperation {
return QueryPredicateOperation(field: name, operator: .notEqual(value))
}


public func ne(_ value: EnumPersistable) -> QueryPredicateOperation {
return QueryPredicateOperation(field: name, operator: .notEqual(value.rawValue))
}
Expand All @@ -181,10 +176,6 @@ public struct QueryField: QueryFieldOperation {
return key.ne(value)
}

public static func != (key: Self, value: Persistable) -> QueryPredicateOperation {
return key.ne(value)
}

public static func != (key: Self, value: EnumPersistable) -> QueryPredicateOperation {
return key.ne(value)
}
Expand Down
30 changes: 30 additions & 0 deletions Amplify/Categories/DataStore/Query/QueryPredicate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,33 @@ import Foundation
/// Protocol that indicates concrete types conforming to it can be used a predicate member.
public protocol QueryPredicate: Evaluable, Encodable {}

public extension QueryPredicate {
func optimize() -> QueryPredicate {
func simplifyIdentityPredicateGroup(_ predicateGroup: QueryPredicateGroup) -> QueryPredicate {
switch predicateGroup.type {
case .id:
return predicateGroup.predicates.first!
default:
return QueryPredicateGroup(
type: predicateGroup.type,
predicates: predicateGroup.predicates.map { $0.optimize() }
)
}
}

if let predicate = self as? QueryPredicateGroup {
return simplifyIdentityPredicateGroup(predicate)
} else {
return self
}
}
}

public enum QueryPredicateGroupType: String, Encodable {
case and
case or
case not
case id
}

/// The `not` function is used to wrap a `QueryPredicate` in a `QueryPredicateGroup` of type `.not`.
Expand Down Expand Up @@ -43,6 +66,10 @@ public class QueryPredicateGroup: QueryPredicate, Encodable {
self.predicates = predicates
}

public convenience init(predicate: QueryPredicate) {
self.init(type: .id, predicates: [predicate])
}

public func and(_ predicate: QueryPredicate) -> QueryPredicateGroup {
if case .and = type {
predicates.append(predicate)
Expand Down Expand Up @@ -90,6 +117,9 @@ public class QueryPredicateGroup: QueryPredicate, Encodable {
case .not:
let predicate = predicates[0]
return !predicate.evaluate(target: target)
case .id:
let predicate = predicates[0]
return predicate.evaluate(target: target)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class AWSRESTOperationTests: OperationTestBase {
}

// TODO: Fix this test
func testGetReturnsOperation() throws {
func testGetReturnsOperation() async throws {
try setUpPlugin(endpointType: .rest)

// Use this as a semaphore to ensure the task is cleaned up before proceeding to the next test
Expand All @@ -50,7 +50,7 @@ class AWSRESTOperationTests: OperationTestBase {

XCTAssertNotNil(operation.request)

waitForExpectations(timeout: 1.00)
await fulfillment(of: [listenerWasInvoked], timeout: 2)
}

func testGetFailsWithBadAPIName() throws {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,12 @@ extension QueryPredicateGroup: GraphQLFilterConvertible {
} else {
return Fatal.preconditionFailure("Missing predicate for \(String(describing: self)) with type: \(type)")
}
case .id:
if let predicate = predicates.first {
return predicate.graphQLFilter(for: modelSchema)
} else {
return Fatal.preconditionFailure("Missing predicate for \(String(describing: self)) with type: \(type)")
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,6 @@ class GraphQLRequestAnyModelWithSyncTests: XCTestCase {
XCTAssertEqual(variables["limit"] as? Int, limit)
XCTAssertEqual(variables["nextToken"] as? String, nextToken)
XCTAssertNotNil(filter)
XCTAssertNotNil(filter["and"])
}

func testSyncQueryGraphQLRequestWithPredicateGroupFilter() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ private func translateQueryPredicate(from modelSchema: ModelSchema,

// the very first `and` is always prepended, using -1 for if statement checking
// the very first `and` is to connect `where` clause with translated QueryPredicate
translate(optimizeQueryPredicateGroup(predicate), predicateIndex: -1, groupType: .and)
translate(optimizeQueryPredicateGroup(predicate.optimize()), predicateIndex: -1, groupType: .and)
return (sql.joined(separator: "\n"), bindings)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ final class InitialSyncOperation: AsynchronousOperation {
private var syncPredicate: QueryPredicate? {
return dataStoreConfiguration.syncExpressions.first {
$0.modelSchema.name == self.modelSchema.name
}?.modelPredicate()
}?.modelPredicate().optimize()
}

private var syncPredicateString: String? {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class QueryPredicateTests: XCTestCase {
let predicate = post.draft.eq(true)

XCTAssertEqual(predicate, post.draft == true)
XCTAssertEqual(predicate, QueryPredicateOperation(field: "draft", operator: .equals(true)))
XCTAssertEqual(predicate, QueryPredicateGroup(predicate: QueryPredicateOperation(field: "draft", operator: .equals(true))))
}

/// it should create a simple `QueryPredicateGroup`
Expand All @@ -41,7 +41,9 @@ class QueryPredicateTests: XCTestCase {
let expected = QueryPredicateGroup(
type: .and,
predicates: [
QueryPredicateOperation(field: "draft", operator: .equals(true)),
QueryPredicateGroup(
predicate:QueryPredicateOperation(field: "draft", operator: .equals(true))
),
QueryPredicateGroup(type: .and, predicates: [
QueryPredicateOperation(field: "id", operator: .attributeExists(true)),
QueryPredicateOperation(field: "id", operator: .notEqual(nil))
Expand All @@ -68,7 +70,9 @@ class QueryPredicateTests: XCTestCase {
let expected = QueryPredicateGroup(
type: .and,
predicates: [
QueryPredicateOperation(field: "draft", operator: .equals(true)),
QueryPredicateGroup(
predicate: QueryPredicateOperation(field: "draft", operator: .equals(true))
),
QueryPredicateGroup(type: .and, predicates: [
QueryPredicateOperation(field: "id", operator: .attributeExists(true)),
QueryPredicateOperation(field: "id", operator: .notEqual(nil))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class ModelIdentifierTests: XCTestCase {
func testModelIdentifierDefaultIdPredicate() {
let model = Post(title: "title", content: "", createdAt: Temporal.DateTime.now())

let predicate = (model.identifier(schema: Post.schema).predicate as? QueryPredicateOperation)!
let predicate = (model.identifier(schema: Post.schema).predicate as? QueryPredicateGroup)!

XCTAssertEqual(predicate, Post.keys.id == model.id)
}
Expand All @@ -79,7 +79,7 @@ class ModelIdentifierTests: XCTestCase {

let identifier = model.identifier(schema: ModelImplicitDefaultPk.schema)

let predicate = (identifier.predicate as? QueryPredicateOperation)!
let predicate = (identifier.predicate as? QueryPredicateGroup)!

XCTAssertEqual(predicate, ModelImplicitDefaultPk.keys.id == model.id)
}
Expand All @@ -89,7 +89,7 @@ class ModelIdentifierTests: XCTestCase {

let identifier = model.identifier(schema: ModelExplicitDefaultPk.schema)

let predicate = (identifier.predicate as? QueryPredicateOperation)!
let predicate = (identifier.predicate as? QueryPredicateGroup)!

XCTAssertEqual(predicate, ModelExplicitDefaultPk.keys.id == model.id)
}
Expand All @@ -99,7 +99,7 @@ class ModelIdentifierTests: XCTestCase {

let identifier = model.identifier(schema: ModelExplicitCustomPk.schema)

let predicate = (identifier.predicate as? QueryPredicateOperation)!
let predicate = (identifier.predicate as? QueryPredicateGroup)!

XCTAssertEqual(predicate, ModelExplicitCustomPk.keys.userId == model.userId)

Expand Down

0 comments on commit a63d07f

Please sign in to comment.