diff --git a/Sources/StateGraph/Observation/Node+Observe.swift b/Sources/StateGraph/Observation/Node+Observe.swift index d798039..3b680b3 100644 --- a/Sources/StateGraph/Observation/Node+Observe.swift +++ b/Sources/StateGraph/Observation/Node+Observe.swift @@ -93,17 +93,8 @@ extension Node { - Note: The sequence starts with the current value, then emits subsequent changes - Note: The sequence continues indefinitely until cancelled or the node is deallocated */ - public func observe() -> AsyncStartWithSequence, Self.Value>> { - - let stream = withStateGraphTrackingStream { - _ = self.wrappedValue - } - .map { - self.wrappedValue - } - .startWith(self.wrappedValue) - - return stream + public func observe() -> AsyncStream { + withStateGraphTrackingStream { self.wrappedValue } } } diff --git a/Sources/StateGraph/Observation/withTracking.swift b/Sources/StateGraph/Observation/withTracking.swift index 0020b6c..76ce0a3 100644 --- a/Sources/StateGraph/Observation/withTracking.swift +++ b/Sources/StateGraph/Observation/withTracking.swift @@ -61,25 +61,33 @@ func withContinuousStateGraphTracking( } } -func withStateGraphTrackingStream( - apply: @escaping () -> Void -) -> AsyncStream { - - AsyncStream { (continuation: AsyncStream.Continuation) in - +public func withStateGraphTrackingStream( + apply: @escaping () -> T, + isolation: isolated (any Actor)? = #isolation +) -> AsyncStream { + + AsyncStream { (continuation: AsyncStream.Continuation) in + let isCancelled = OSAllocatedUnfairLock(initialState: false) - + continuation.onTermination = { termination in isCancelled.withLock { $0 = true } } + + withContinuousStateGraphTracking( + apply: { + let value = apply() + continuation.yield(value) + }, + didChange: { + if isCancelled.withLock({ $0 }) { + return .stop + } + return .next + }, + isolation: isolation + ) - withContinuousStateGraphTracking(apply: apply) { - continuation.yield() - if isCancelled.withLock({ $0 }) { - return .stop - } - return .next - } } } diff --git a/Tests/StateGraphTests/Tests.swift b/Tests/StateGraphTests/Tests.swift index 0e40abc..aa13d9a 100644 --- a/Tests/StateGraphTests/Tests.swift +++ b/Tests/StateGraphTests/Tests.swift @@ -372,41 +372,42 @@ struct StateGraphTrackingTests { } } -@Suite -struct GraphViewAdvancedTests { - final class NestedModel: Sendable { +final class NestedModel: Sendable { - @GraphStored - var counter: Int = 0 + @GraphStored + var counter: Int = 0 - @GraphStored - var subModel: SubModel? + @GraphStored + var subModel: SubModel? - init() { - self.subModel = nil - } + init() { + self.subModel = nil + } - func incrementCounter() { - counter += 1 - } + func incrementCounter() { + counter += 1 + } - func createSubModel() { - subModel = SubModel() - } + func createSubModel() { + subModel = SubModel() } +} - final class SubModel: Sendable { - @GraphStored - var value: String = "default" +final class SubModel: Sendable { + @GraphStored + var value: String = "default" - init() {} + init() {} - func updateValue(_ newValue: String) { - value = newValue - } + func updateValue(_ newValue: String) { + value = newValue } +} +@Suite +struct GraphViewAdvancedTests { + @Test func nested_model_tracking() async { let model = NestedModel() model.createSubModel() @@ -444,12 +445,56 @@ struct GraphViewAdvancedTests { } } +} + +import Foundation + +@Suite +struct StreamTests { + + @Test func projection_tracking() async { + let model = NestedModel() + + await confirmation(expectedCount: 4) { c in + let receivedValues = OSAllocatedUnfairLock<[Int]>(initialState: []) + + let task = Task { + // Test that withStateGraphTrackingStream now returns projected values directly + for await value in withStateGraphTrackingStream(apply: { + model.counter // Returns Int directly + }) { + receivedValues.withLock { $0.append(value) } + c.confirm() + if value == 3 { + break + } + } + } + + try! await Task.sleep(for: .milliseconds(100)) + + model.counter = 1 + try! await Task.sleep(for: .milliseconds(100)) + + model.counter = 2 + try! await Task.sleep(for: .milliseconds(100)) + + model.counter = 3 + try! await Task.sleep(for: .milliseconds(100)) + + await task.value + + // Verify we received the projected values: initial (0) + 3 changes + #expect(receivedValues.withLock { $0 } == [0, 1, 2, 3]) + } + } + @Test func continuous_tracking() async { let model = NestedModel() - await confirmation(expectedCount: 3) { c in + await confirmation(expectedCount: 4) { c in - let expectation = OSAllocatedUnfairLock(initialState: -1) + let expectation = OSAllocatedUnfairLock(initialState: 0) let task = Task { for await _ in withStateGraphTrackingStream(apply: { @@ -485,5 +530,54 @@ struct GraphViewAdvancedTests { } } -} + + @Test func continuous_tracking_main() async { + let model = NestedModel() + + await confirmation(expectedCount: 4) { c in + + let expectation = OSAllocatedUnfairLock(initialState: 0) + + let task = Task { @MainActor in + + let stream = withStateGraphTrackingStream(apply: { + assert(Thread.isMainThread, "Because this stream has been created on MainActor.") + _ = model.counter + }) + + Task.detached { + for await _ in stream { + print(model.counter) + #expect(model.counter == expectation.withLock { $0 }) + c.confirm() + if model.counter == 3 { + break + } + } + } + + } + + try! await Task.sleep(for: .milliseconds(100)) + + // Trigger updates + expectation.withLock { $0 = 1 } + model.counter = expectation.withLock { $0 } + + try! await Task.sleep(for: .milliseconds(100)) + + expectation.withLock { $0 = 2 } + model.counter = expectation.withLock { $0 } + try! await Task.sleep(for: .milliseconds(100)) + expectation.withLock { $0 = 3 } + model.counter = expectation.withLock { $0 } + + try! await Task.sleep(for: .milliseconds(100)) + + await task.value + } + + } + +}