Skip to content

Commit 93518d7

Browse files
committed
update the eq and ne api signatures
1 parent d07f4cd commit 93518d7

File tree

6 files changed

+125
-22
lines changed

6 files changed

+125
-22
lines changed

Amplify/Categories/DataStore/Query/ModelKey.swift

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,23 @@ extension CodingKey where Self: ModelKey {
6868

6969
// MARK: - eq
7070

71-
public func eq(_ value: Persistable?) -> QueryPredicateOperation {
71+
public func eq(_ value: Persistable?) -> QueryPredicateGroup {
72+
return field(stringValue).eq(value)
73+
}
74+
75+
public func eq(_ value: Persistable) -> QueryPredicateOperation {
7276
return field(stringValue).eq(value)
7377
}
7478

7579
public func eq(_ value: EnumPersistable) -> QueryPredicateOperation {
7680
return field(stringValue).eq(value)
7781
}
7882

79-
public static func == (key: Self, value: Persistable?) -> QueryPredicateOperation {
83+
public static func == (key: Self, value: Persistable?) -> QueryPredicateGroup {
84+
return key.eq(value)
85+
}
86+
87+
public static func == (key: Self, value: Persistable) -> QueryPredicateOperation {
8088
return key.eq(value)
8189
}
8290

@@ -126,15 +134,23 @@ extension CodingKey where Self: ModelKey {
126134

127135
// MARK: - ne
128136

129-
public func ne(_ value: Persistable?) -> QueryPredicateOperation {
137+
public func ne(_ value: Persistable?) -> QueryPredicateGroup {
138+
return field(stringValue).ne(value)
139+
}
140+
141+
public func ne(_ value: Persistable) -> QueryPredicateOperation {
130142
return field(stringValue).ne(value)
131143
}
132144

133145
public func ne(_ value: EnumPersistable) -> QueryPredicateOperation {
134146
return field(stringValue).ne(value)
135147
}
136148

137-
public static func != (key: Self, value: Persistable?) -> QueryPredicateOperation {
149+
public static func != (key: Self, value: Persistable?) -> QueryPredicateGroup {
150+
return key.ne(value)
151+
}
152+
153+
public static func != (key: Self, value: Persistable) -> QueryPredicateOperation {
138154
return key.ne(value)
139155
}
140156

Amplify/Categories/DataStore/Query/QueryField.swift

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,25 +35,29 @@ public protocol QueryFieldOperation {
3535
func between(start: Persistable, end: Persistable) -> QueryPredicateOperation
3636
func contains(_ value: String) -> QueryPredicateOperation
3737
func notContains(_ value: String) -> QueryPredicateOperation
38-
func eq(_ value: Persistable?) -> QueryPredicateOperation
38+
func eq(_ value: Persistable) -> QueryPredicateOperation
39+
func eq(_ value: Persistable?) -> QueryPredicateGroup
3940
func eq(_ value: EnumPersistable) -> QueryPredicateOperation
4041
func ge(_ value: Persistable) -> QueryPredicateOperation
4142
func gt(_ value: Persistable) -> QueryPredicateOperation
4243
func le(_ value: Persistable) -> QueryPredicateOperation
4344
func lt(_ value: Persistable) -> QueryPredicateOperation
44-
func ne(_ value: Persistable?) -> QueryPredicateOperation
45+
func ne(_ value: Persistable) -> QueryPredicateOperation
46+
func ne(_ value: Persistable?) -> QueryPredicateGroup
4547
func ne(_ value: EnumPersistable) -> QueryPredicateOperation
4648

4749
// MARK: - Operators
4850

4951
static func ~= (key: Self, value: String) -> QueryPredicateOperation
50-
static func == (key: Self, value: Persistable?) -> QueryPredicateOperation
52+
static func == (key: Self, value: Persistable) -> QueryPredicateOperation
53+
static func == (key: Self, value: Persistable?) -> QueryPredicateGroup
5154
static func == (key: Self, value: EnumPersistable) -> QueryPredicateOperation
5255
static func >= (key: Self, value: Persistable) -> QueryPredicateOperation
5356
static func > (key: Self, value: Persistable) -> QueryPredicateOperation
5457
static func <= (key: Self, value: Persistable) -> QueryPredicateOperation
5558
static func < (key: Self, value: Persistable) -> QueryPredicateOperation
56-
static func != (key: Self, value: Persistable?) -> QueryPredicateOperation
59+
static func != (key: Self, value: Persistable) -> QueryPredicateOperation
60+
static func != (key: Self, value: Persistable?) -> QueryPredicateGroup
5761
static func != (key: Self, value: EnumPersistable) -> QueryPredicateOperation
5862
}
5963

@@ -92,16 +96,24 @@ public struct QueryField: QueryFieldOperation {
9296
}
9397

9498
// MARK: - eq
95-
96-
public func eq(_ value: Persistable?) -> QueryPredicateOperation {
99+
public func eq(_ value: Persistable) -> QueryPredicateOperation {
97100
return QueryPredicateOperation(field: name, operator: .equals(value))
98101
}
99102

103+
public func eq(_ value: Persistable?) -> QueryPredicateGroup {
104+
return QueryPredicateOperation(field: name, operator: .attributeExists(false))
105+
|| QueryPredicateOperation(field: name, operator: .equals(value))
106+
}
107+
100108
public func eq(_ value: EnumPersistable) -> QueryPredicateOperation {
101109
return QueryPredicateOperation(field: name, operator: .equals(value.rawValue))
102110
}
103111

104-
public static func == (key: Self, value: Persistable?) -> QueryPredicateOperation {
112+
public static func == (key: Self, value: Persistable?) -> QueryPredicateGroup {
113+
return key.eq(value)
114+
}
115+
116+
public static func == (key: Self, value: Persistable) -> QueryPredicateOperation {
105117
return key.eq(value)
106118
}
107119

@@ -151,15 +163,25 @@ public struct QueryField: QueryFieldOperation {
151163

152164
// MARK: - ne
153165

154-
public func ne(_ value: Persistable?) -> QueryPredicateOperation {
166+
public func ne(_ value: Persistable?) -> QueryPredicateGroup {
167+
return QueryPredicateOperation(field: name, operator: .attributeExists(true))
168+
&& QueryPredicateOperation(field: name, operator: .notEqual(value))
169+
}
170+
171+
public func ne(_ value: Persistable) -> QueryPredicateOperation {
155172
return QueryPredicateOperation(field: name, operator: .notEqual(value))
156173
}
157174

175+
158176
public func ne(_ value: EnumPersistable) -> QueryPredicateOperation {
159177
return QueryPredicateOperation(field: name, operator: .notEqual(value.rawValue))
160178
}
161179

162-
public static func != (key: Self, value: Persistable?) -> QueryPredicateOperation {
180+
public static func != (key: Self, value: Persistable?) -> QueryPredicateGroup {
181+
return key.ne(value)
182+
}
183+
184+
public static func != (key: Self, value: Persistable) -> QueryPredicateOperation {
163185
return key.ne(value)
164186
}
165187

Amplify/Categories/DataStore/Query/QueryOperator+Equatable.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ extension QueryOperator: Equatable {
2424
case let (.between(oneStart, oneEnd), .between(otherStart, otherEnd)):
2525
return PersistableHelper.isEqual(oneStart, otherStart)
2626
&& PersistableHelper.isEqual(oneEnd, otherEnd)
27+
case let (.attributeExists(lhs), .attributeExists(rhs)):
28+
return lhs == rhs
2729
default:
2830
return false
2931
}

AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Storage/SQLite/SQLStatement+Condition.swift

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,66 @@ private func translateQueryPredicate(from modelSchema: ModelSchema,
7676
return operation.field.quoted()
7777
}
7878

79+
func optimizeQueryPredicateGroup(_ predicate: QueryPredicate) -> QueryPredicate {
80+
func rewritePredicate(_ predicate: QueryPredicate) -> QueryPredicate {
81+
if let operation = predicate as? QueryPredicateOperation {
82+
switch operation.operator {
83+
case .attributeExists(let bool):
84+
return QueryPredicateOperation(
85+
field: operation.field,
86+
operator: bool ? .notEqual(nil) : .equals(nil)
87+
)
88+
default:
89+
return operation
90+
}
91+
} else if let group = predicate as? QueryPredicateGroup {
92+
return optimizeQueryPredicateGroup(group)
93+
}
94+
95+
return predicate
96+
}
97+
98+
func removeDuplicatePredicate(_ predicates: [QueryPredicate]) -> [QueryPredicate] {
99+
var result = [QueryPredicate]()
100+
for predicate in predicates {
101+
let hasSameExpression = result.reduce(false) {
102+
if $0 { return $0 }
103+
switch ($1, predicate) {
104+
case let (lhs as QueryPredicateOperation, rhs as QueryPredicateOperation):
105+
return lhs == rhs
106+
case let (lhs as QueryPredicateGroup, rhs as QueryPredicateGroup):
107+
return lhs == rhs
108+
default:
109+
return false
110+
}
111+
}
112+
113+
if !hasSameExpression {
114+
result.append(predicate)
115+
}
116+
}
117+
return result
118+
}
119+
120+
switch predicate {
121+
case let predicate as QueryPredicateGroup:
122+
let optimizedPredicates = removeDuplicatePredicate(predicate.predicates.reduce([]) {
123+
$0 + [rewritePredicate($1)]
124+
})
125+
126+
if optimizedPredicates.count == 1 {
127+
return optimizedPredicates.first!
128+
} else {
129+
return QueryPredicateGroup(type: predicate.type, predicates: optimizedPredicates)
130+
}
131+
default:
132+
return predicate
133+
}
134+
}
135+
79136
// the very first `and` is always prepended, using -1 for if statement checking
80137
// the very first `and` is to connect `where` clause with translated QueryPredicate
81-
translate(predicate, predicateIndex: -1, groupType: .and)
138+
translate(optimizeQueryPredicateGroup(predicate), predicateIndex: -1, groupType: .and)
82139
return (sql.joined(separator: "\n"), bindings)
83140
}
84141

AmplifyPlugins/DataStore/Tests/AWSDataStorePluginTests/Core/QueryPredicateTests.swift

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@ class QueryPredicateTests: XCTestCase {
4242
type: .and,
4343
predicates: [
4444
QueryPredicateOperation(field: "draft", operator: .equals(true)),
45-
QueryPredicateOperation(field: "id", operator: .notEqual(nil))
45+
QueryPredicateGroup(type: .and, predicates: [
46+
QueryPredicateOperation(field: "id", operator: .attributeExists(true)),
47+
QueryPredicateOperation(field: "id", operator: .notEqual(nil))
48+
])
4649
]
4750
)
4851

@@ -66,7 +69,10 @@ class QueryPredicateTests: XCTestCase {
6669
type: .and,
6770
predicates: [
6871
QueryPredicateOperation(field: "draft", operator: .equals(true)),
69-
QueryPredicateOperation(field: "id", operator: .notEqual(nil)),
72+
QueryPredicateGroup(type: .and, predicates: [
73+
QueryPredicateOperation(field: "id", operator: .attributeExists(true)),
74+
QueryPredicateOperation(field: "id", operator: .notEqual(nil))
75+
]),
7076
QueryPredicateGroup(
7177
type: .or,
7278
predicates: [

AmplifyPlugins/DataStore/Tests/AWSDataStorePluginTests/Sync/Support/MutationEventQueryTests.swift

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class MutationEventQueryTests: BaseDataStoreTests {
3232
wait(for: [querySuccess], timeout: 1)
3333
}
3434

35-
func testQueryPendingMutationEvent() {
35+
func testQueryPendingMutationEvent() async {
3636
let mutationEvent = generateRandomMutationEvent()
3737

3838
let querySuccess = expectation(description: "query for pending mutation events")
@@ -55,10 +55,10 @@ class MutationEventQueryTests: BaseDataStoreTests {
5555
case .failure(let error): XCTFail("\(error)")
5656
}
5757
}
58-
wait(for: [querySuccess], timeout: 1)
58+
await fulfillment(of: [querySuccess], timeout: 1)
5959
}
6060

61-
func testQueryPendingMutationEventsForModelIds() {
61+
func testQueryPendingMutationEventsForModelIds() async {
6262
let mutationEvent1 = generateRandomMutationEvent()
6363
let mutationEvent2 = generateRandomMutationEvent()
6464

@@ -70,7 +70,7 @@ class MutationEventQueryTests: BaseDataStoreTests {
7070
}
7171
saveMutationEvent1.fulfill()
7272
}
73-
wait(for: [saveMutationEvent1], timeout: 1)
73+
await fulfillment(of: [saveMutationEvent1], timeout: 1)
7474

7575
let saveMutationEvent2 = expectation(description: "save mutationEvent1 success")
7676
storageAdapter.save(mutationEvent2) { result in
@@ -80,7 +80,7 @@ class MutationEventQueryTests: BaseDataStoreTests {
8080
}
8181
saveMutationEvent2.fulfill()
8282
}
83-
wait(for: [saveMutationEvent2], timeout: 1)
83+
await fulfillment(of: [saveMutationEvent2], timeout: 1)
8484

8585
let querySuccess = expectation(description: "query for metadata success")
8686
var mutationEvents = [mutationEvent1]
@@ -98,7 +98,7 @@ class MutationEventQueryTests: BaseDataStoreTests {
9898
}
9999
}
100100

101-
wait(for: [querySuccess], timeout: 1)
101+
await fulfillment(of: [querySuccess], timeout: 3)
102102
}
103103

104104
private func generateRandomMutationEvent() -> MutationEvent {

0 commit comments

Comments
 (0)