Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 2 additions & 11 deletions Sources/StateGraph/Observation/Node+Observe.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,8 @@ extension Node {
- Note: The sequence starts with the current value, then emits subsequent changes
- Note: The sequence continues indefinitely until cancelled or the node is deallocated
*/
public func observe() -> AsyncStartWithSequence<AsyncMapSequence<AsyncStream<Void>, Self.Value>> {

let stream = withStateGraphTrackingStream {
_ = self.wrappedValue
}
.map {
self.wrappedValue
}
.startWith(self.wrappedValue)

return stream
public func observe() -> AsyncStream<Self.Value> {
withStateGraphTrackingStream { self.wrappedValue }
}

}
Expand Down
36 changes: 22 additions & 14 deletions Sources/StateGraph/Observation/withTracking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -61,25 +61,33 @@ func withContinuousStateGraphTracking<R>(
}
}

func withStateGraphTrackingStream(
apply: @escaping () -> Void
) -> AsyncStream<Void> {

AsyncStream<Void> { (continuation: AsyncStream<Void>.Continuation) in

public func withStateGraphTrackingStream<T>(
apply: @escaping () -> T,
isolation: isolated (any Actor)? = #isolation
) -> AsyncStream<T> {

AsyncStream<T> { (continuation: AsyncStream<T>.Continuation) in

let isCancelled = OSAllocatedUnfairLock(initialState: false)

continuation.onTermination = { termination in
isCancelled.withLock { $0 = true }
}

withContinuousStateGraphTracking(
apply: {
let value = apply()
continuation.yield(value)
},
didChange: {
if isCancelled.withLock({ $0 }) {
return .stop
}
return .next
},
isolation: isolation
)

withContinuousStateGraphTracking(apply: apply) {
continuation.yield()
if isCancelled.withLock({ $0 }) {
return .stop
}
return .next
}
}
}

Expand Down
146 changes: 120 additions & 26 deletions Tests/StateGraphTests/Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -372,41 +372,42 @@ struct StateGraphTrackingTests {
}
}

@Suite
struct GraphViewAdvancedTests {

final class NestedModel: Sendable {
final class NestedModel: Sendable {

@GraphStored
var counter: Int = 0
@GraphStored
var counter: Int = 0

@GraphStored
var subModel: SubModel?
@GraphStored
var subModel: SubModel?

init() {
self.subModel = nil
}
init() {
self.subModel = nil
}

func incrementCounter() {
counter += 1
}
func incrementCounter() {
counter += 1
}

func createSubModel() {
subModel = SubModel()
}
func createSubModel() {
subModel = SubModel()
}
}

final class SubModel: Sendable {
@GraphStored
var value: String = "default"
final class SubModel: Sendable {
@GraphStored
var value: String = "default"

init() {}
init() {}

func updateValue(_ newValue: String) {
value = newValue
}
func updateValue(_ newValue: String) {
value = newValue
}
}

@Suite
struct GraphViewAdvancedTests {

@Test func nested_model_tracking() async {
let model = NestedModel()
model.createSubModel()
Expand Down Expand Up @@ -444,12 +445,56 @@ struct GraphViewAdvancedTests {
}
}

}

import Foundation

@Suite
struct StreamTests {

@Test func projection_tracking() async {
let model = NestedModel()

await confirmation(expectedCount: 4) { c in
let receivedValues = OSAllocatedUnfairLock<[Int]>(initialState: [])

let task = Task {
// Test that withStateGraphTrackingStream now returns projected values directly
for await value in withStateGraphTrackingStream(apply: {
model.counter // Returns Int directly
}) {
receivedValues.withLock { $0.append(value) }
c.confirm()
if value == 3 {
break
}
}
}

try! await Task.sleep(for: .milliseconds(100))

model.counter = 1
try! await Task.sleep(for: .milliseconds(100))

model.counter = 2
try! await Task.sleep(for: .milliseconds(100))

model.counter = 3
try! await Task.sleep(for: .milliseconds(100))

await task.value

// Verify we received the projected values: initial (0) + 3 changes
#expect(receivedValues.withLock { $0 } == [0, 1, 2, 3])
}
}

@Test func continuous_tracking() async {
let model = NestedModel()

await confirmation(expectedCount: 3) { c in
await confirmation(expectedCount: 4) { c in

let expectation = OSAllocatedUnfairLock<Int>(initialState: -1)
let expectation = OSAllocatedUnfairLock<Int>(initialState: 0)

let task = Task {
for await _ in withStateGraphTrackingStream(apply: {
Expand Down Expand Up @@ -485,5 +530,54 @@ struct GraphViewAdvancedTests {
}

}
}

@Test func continuous_tracking_main() async {
let model = NestedModel()

await confirmation(expectedCount: 4) { c in

let expectation = OSAllocatedUnfairLock<Int>(initialState: 0)

let task = Task { @MainActor in

let stream = withStateGraphTrackingStream(apply: {
assert(Thread.isMainThread, "Because this stream has been created on MainActor.")
_ = model.counter
})

Task.detached {
for await _ in stream {
print(model.counter)
#expect(model.counter == expectation.withLock { $0 })
c.confirm()
if model.counter == 3 {
break
}
}
}

}

try! await Task.sleep(for: .milliseconds(100))

// Trigger updates
expectation.withLock { $0 = 1 }
model.counter = expectation.withLock { $0 }

try! await Task.sleep(for: .milliseconds(100))

expectation.withLock { $0 = 2 }
model.counter = expectation.withLock { $0 }
try! await Task.sleep(for: .milliseconds(100))

expectation.withLock { $0 = 3 }
model.counter = expectation.withLock { $0 }

try! await Task.sleep(for: .milliseconds(100))

await task.value
}

}

}
Loading