Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IOLocal propagation for unsafe access #3636

Merged
merged 39 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
0b88c01
POC thread-local iolocals
armanbilge May 16, 2023
db743e2
Simplify and optimize
armanbilge May 16, 2023
716ef32
Special-case for `WorkerThread`
armanbilge May 17, 2023
0a69caf
Load locals in `unsafeRunFiber`
armanbilge May 17, 2023
2775064
Dump locals in more places
armanbilge May 18, 2023
270764f
Refactor `IOLocal`
armanbilge May 21, 2023
d55489d
Use new `IOLocal` APIs in `IOLocals`
armanbilge May 21, 2023
2cf72a5
Mark `IOLocal` methods as `final`
armanbilge May 21, 2023
cb3859d
Add `IOLocalsSpec`
armanbilge Jun 10, 2023
7dce01c
Rename property to `ioLocalPropagation` and fixes
armanbilge Jun 28, 2023
5e171ac
Bump base version
armanbilge Jun 28, 2023
c2f312d
Add files I forgot tocommit :)
armanbilge Jun 28, 2023
638930d
Merge branch 'series/3.x' into topic/thread-local-iolocal
armanbilge Jun 28, 2023
9174c6a
Add MiMa filters
armanbilge Jun 28, 2023
1987e3a
Fix scaladoc links
armanbilge Jun 28, 2023
02a43a6
Alias the disambiguations
armanbilge Jun 28, 2023
a7bf748
Copy locals back out after blocking unsafe run
armanbilge Sep 5, 2023
145fc0e
Merge remote-tracking branch 'upstream/series/3.x' into topic/thread-…
armanbilge Sep 5, 2023
fa99a5c
Expose status of `IOLocal` propagation
armanbilge Sep 25, 2023
6cad03c
`propagating` -> `arePropagating`
armanbilge Sep 29, 2023
bb5d4b1
Merge branch 'series/3.x' into topic/thread-local-iolocal
armanbilge Sep 30, 2023
7517755
Use `real` instead of `ticked`
armanbilge Sep 30, 2023
8d8e004
Formatting
armanbilge Sep 30, 2023
3589db4
Try keeping the current fiber as a thread-local instead
armanbilge Sep 30, 2023
522677e
Revert spurious whitespace changes
armanbilge Sep 30, 2023
6cc4d38
Merge branch 'series/3.x' into topic/thread-local-iolocal
armanbilge May 27, 2024
ac88480
Update headers
armanbilge May 27, 2024
49e5c30
Update platform headers
armanbilge May 27, 2024
925f504
Remove unused class
armanbilge May 27, 2024
d63a6ff
Expose `IOLocal` propagation as a `ThreadLocal`
armanbilge Jun 4, 2024
d4549fb
`unsafeToThreadLocal()` throws if propagation disabled
armanbilge Jun 4, 2024
2502045
Add scaladoc
armanbilge Jun 5, 2024
535fc8a
Factor out to JVM-only
armanbilge Jun 5, 2024
d854799
Bikeshed API and docs
armanbilge Jun 5, 2024
f070552
Formatting
armanbilge Jun 5, 2024
2cf1d8a
Delete dead code
armanbilge Jun 5, 2024
0eec9dd
Document `ThreadLocal` propagation
armanbilge Aug 5, 2024
af84973
Merge branch 'series/3.x' into topic/thread-local-iolocal
armanbilge Nov 14, 2024
1adf368
Merge branch 'series/3.x' into topic/thread-local-iolocal
armanbilge Nov 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,15 @@ lazy val core = crossProject(JSPlatform, JVMPlatform, NativePlatform)
"cats.effect.unsafe.IORuntimeBuilder.this"),
// introduced by #3695, which enabled fiber dumps on native
ProblemFilters.exclude[MissingClassProblem](
"cats.effect.unsafe.FiberMonitorCompanionPlatform")
"cats.effect.unsafe.FiberMonitorCompanionPlatform"),
// introduced by #3636, IOLocal propagation
// IOLocal is a sealed trait
ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.IOLocal.getOrDefault"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.IOLocal.set"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.IOLocal.reset"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.IOLocal.lens"),
// this filter is particulary terrible, because it can also mask real issues :(
ProblemFilters.exclude[DirectMissingMethodProblem]("cats.effect.IOLocal.lens")
) ++ {
if (tlIsScala3.value) {
// Scala 3 specific exclusions
Expand Down Expand Up @@ -890,7 +898,11 @@ lazy val tests: CrossProject = crossProject(JSPlatform, JVMPlatform, NativePlatf
scalacOptions ~= { _.filterNot(_.startsWith("-P:scalajs:mapSourceURI")) }
)
.jvmSettings(
Test / fork := true
Test / fork := true,
Test / javaOptions ++= Seq(
"-Dcats.effect.ioLocalPropagation=true",
s"-Dsbt.classpath=${(Test / fullClasspath).value.map(_.data.getAbsolutePath).mkString(File.pathSeparator)}"
)
)
.nativeSettings(
Compile / mainClass := Some("catseffect.examples.NativeRunner")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,6 @@ private object IOFiberConstants {
final val CedeR = 6
final val AutoCedeR = 7
final val DoneR = 8

final val ioLocalPropagation = false
djspiewak marked this conversation as resolved.
Show resolved Hide resolved
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright 2020-2023 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect

private object IOLocalsConstants {
final val ioLocalPropagation = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,5 @@ private[unsafe] sealed abstract class WorkerThread[P] private () extends Thread
private[unsafe] def isOwnedBy(threadPool: WorkStealingThreadPool[_]): Boolean
private[unsafe] def monitor(fiber: Runnable): WeakBag.Handle
private[unsafe] def index: Int
private[unsafe] var ioLocalState: IOLocalState
}
2 changes: 2 additions & 0 deletions core/jvm/src/main/java/cats/effect/IOFiberConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,6 @@ final class IOFiberConstants {
static final byte CedeR = 6;
static final byte AutoCedeR = 7;
static final byte DoneR = 8;

static final boolean ioLocalPropagation = Boolean.getBoolean("cats.effect.ioLocalPropagation");
}
22 changes: 22 additions & 0 deletions core/jvm/src/main/java/cats/effect/unsafe/IOLocalsConstants.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright 2020-2023 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect.unsafe;

// defined in Java since Scala doesn't let us define static fields
final class IOLocalsConstants {
static final boolean ioLocalPropagation = Boolean.getBoolean("cats.effect.ioLocalPropagation");

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it would be very helpful if this was exposed to the end user in some way so that they can check if it's enabled and potentially raise an error if not (for example, we don't want people trying to use an IOLocal-backed ContextStorageProvider and its associated Local in otle4s if this value is false)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great point. fa99a5c, wdyt?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, cheers!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very cool!

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this be a more explicit option type

You mean so that it can be configured by user-code at runtime?

The advantage of using system properties with static final fields is that their values are constant at JVM startup, which allows the JIT to optimize those branches. Allowing dynamic configuration would circumvent that.

Or if I missed your point, sorry, please explain 😅 note we also have a lot of configuration available in IORuntimeConfig class which can be configured programmatically at runtime.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@armanbilge Oh, I didn't mean to configure it at runtime. I meant that it would be good to have a proper type for setting local propagation instead of having just Boolean.
But, I got your point of static final, so please ignore what I said. 😅

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant that it would be good to have a proper type for setting local propagation instead of having just Boolean.

Got it, thanks :) Monix also seems to be using just a Boolean for setting local propagation anyway. The difference is that it allows it to be configured dynamically per task.

https://github.com/monix/monix/blob/952034ad4884b64fb4241d78f33fdef07f267157/monix-eval/shared/src/main/scala/monix/eval/Task.scala#L4406

}
5 changes: 4 additions & 1 deletion core/jvm/src/main/scala/cats/effect/IOPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ abstract private[effect] class IOPlatform[+A] extends Serializable { self: IO[A]
implicit runtime: unsafe.IORuntime): Option[A] = {
val queue = new ArrayBlockingQueue[Either[Throwable, A]](1)

unsafeRunAsync { r =>
val fiber = unsafeRunAsyncImpl { r =>
queue.offer(r)
()
}
Expand All @@ -82,6 +82,9 @@ abstract private[effect] class IOPlatform[+A] extends Serializable { self: IO[A]
} catch {
case _: InterruptedException =>
None
} finally {
if (IOFiberConstants.ioLocalPropagation)
unsafe.IOLocals.setState(fiber.getLocalState())
}
}

Expand Down
2 changes: 2 additions & 0 deletions core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ private final class WorkerThread[P](
private val indexTransfer: LinkedTransferQueue[Integer] = new LinkedTransferQueue()
private[this] val runtimeBlockingExpiration: Duration = pool.runtimeBlockingExpiration

private[unsafe] var ioLocalState: IOLocalState = IOLocalState.empty

val nameIndex: Int = pool.blockedWorkerThreadNamingIndex.getAndIncrement()

// Constructor code.
Expand Down
10 changes: 7 additions & 3 deletions core/shared/src/main/scala/cats/effect/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,12 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] {
*/
def unsafeRunAsync(cb: Either[Throwable, A] => Unit)(
implicit runtime: unsafe.IORuntime): Unit = {
unsafeRunAsyncImpl(cb)
()
}

private[effect] def unsafeRunAsyncImpl(cb: Either[Throwable, A] => Unit)(
implicit runtime: unsafe.IORuntime): IOFiber[A @uncheckedVariance] =
djspiewak marked this conversation as resolved.
Show resolved Hide resolved
unsafeRunFiber(
cb(Left(new CancellationException("The fiber was canceled"))),
t => {
Expand All @@ -922,8 +928,6 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] {
},
a => cb(Right(a))
)
()
}

def unsafeRunAsyncOutcome(cb: Outcome[Id, Throwable, A @uncheckedVariance] => Unit)(
implicit runtime: unsafe.IORuntime): Unit = {
Expand Down Expand Up @@ -1026,7 +1030,7 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] {
implicit runtime: unsafe.IORuntime): IOFiber[A @uncheckedVariance] = {

val fiber = new IOFiber[A](
Map.empty,
if (IOFiberConstants.ioLocalPropagation) unsafe.IOLocals.getState else Map.empty,
oc =>
oc.fold(
{
Expand Down
51 changes: 51 additions & 0 deletions core/shared/src/main/scala/cats/effect/IOFiber.scala
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ private final class IOFiber[A](
@volatile
private[this] var outcome: OutcomeIO[A] = _

def getLocalState(): IOLocalState = localState

override def run(): Unit = {
// insert a read barrier after every async boundary
readBarrier()
Expand Down Expand Up @@ -250,6 +252,10 @@ private final class IOFiber[A](
pushTracingEvent(cur.event)
}

if (ioLocalPropagation) {
IOLocals.setState(localState)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dumb question: can't we simply do this when we get scheduled on a thread? We know when we're on a thread and we know when we get off of it, so can't we simply set and clear the state respectively at those points?

Copy link
Member Author

@armanbilge armanbilge Sep 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No we can't, unless we unify how the state is represented. Currently it's a var to an immutable map in the fiber and also in the thread. While the fiber is running its copy of the var may be updated effectually in the runloop so the thread-local copy would need to be kept in sync with that. Or we could drive all updates through the thread-local copy of the var, but then there would be a penalty for accessing it esp. if we are not running on a worker thread.

Putting aside technical issues, nobody should be unsafely messing about with IOLocals outside of a properly suspended side-effect block and this strategy enforces that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What we can do is set the current fiber in a thread local every time we get scheduled on a thread. Then the unsafe IOLocals manipulations can operate on the state via the current fiber and we don't need to pay the penalty for every delay block. Based on the benchmarks this strategy is seeming more attractive 😅

Note this would leave the fiber's IOLocal state exposed to unsafe manipulations outside of delay blocks.

}

var error: Throwable = null
val r =
try cur.thunk()
Expand All @@ -260,6 +266,10 @@ private final class IOFiber[A](
onFatalFailure(t)
}

if (ioLocalPropagation) {
localState = IOLocals.getAndClearState()
}

val next =
if (error == null) succeeded(r, 0)
else failed(error, 0)
Expand Down Expand Up @@ -324,6 +334,10 @@ private final class IOFiber[A](
pushTracingEvent(delay.event)
}

if (ioLocalPropagation) {
IOLocals.setState(localState)
}

// this code is inlined in order to avoid two `try` blocks
var error: Throwable = null
val result =
Expand All @@ -335,6 +349,10 @@ private final class IOFiber[A](
onFatalFailure(t)
}

if (ioLocalPropagation) {
localState = IOLocals.getAndClearState()
}

val nextIO = if (error == null) succeeded(result, 0) else failed(error, 0)
runLoop(nextIO, nextCancelation - 1, nextAutoCede)

Expand Down Expand Up @@ -391,6 +409,10 @@ private final class IOFiber[A](
pushTracingEvent(delay.event)
}

if (ioLocalPropagation) {
IOLocals.setState(localState)
}

// this code is inlined in order to avoid two `try` blocks
val result =
try f(delay.thunk())
Expand All @@ -401,6 +423,10 @@ private final class IOFiber[A](
onFatalFailure(t)
}

if (ioLocalPropagation) {
localState = IOLocals.getAndClearState()
}

runLoop(result, nextCancelation - 1, nextAutoCede)

case 3 =>
Expand Down Expand Up @@ -446,6 +472,10 @@ private final class IOFiber[A](
pushTracingEvent(delay.event)
}

if (ioLocalPropagation) {
IOLocals.setState(localState)
}

// this code is inlined in order to avoid two `try` blocks
var error: Throwable = null
val result =
Expand All @@ -460,6 +490,10 @@ private final class IOFiber[A](
onFatalFailure(t)
}

if (ioLocalPropagation) {
localState = IOLocals.getAndClearState()
}

val next =
if (error == null) succeeded(Right(result), 0) else succeeded(Left(error), 0)
runLoop(next, nextCancelation - 1, nextAutoCede)
Expand Down Expand Up @@ -973,6 +1007,10 @@ private final class IOFiber[A](
if (ec.isInstanceOf[WorkStealingThreadPool[_]]) {
val wstp = ec.asInstanceOf[WorkStealingThreadPool[_]]
if (wstp.canExecuteBlockingCode()) {
if (ioLocalPropagation) {
IOLocals.setState(localState)
}

var error: Throwable = null
val r =
try {
Expand All @@ -984,6 +1022,10 @@ private final class IOFiber[A](
onFatalFailure(t)
}

if (ioLocalPropagation) {
localState = IOLocals.getAndClearState()
}

val next = if (error eq null) succeeded(r, 0) else failed(error, 0)
runLoop(next, nextCancelation, nextAutoCede)
} else {
Expand Down Expand Up @@ -1390,6 +1432,11 @@ private final class IOFiber[A](
var error: Throwable = null
val cur = resumeIO.asInstanceOf[Blocking[Any]]
resumeIO = null

if (ioLocalPropagation) {
IOLocals.setState(localState)
}

val r =
try cur.thunk()
catch {
Expand All @@ -1399,6 +1446,10 @@ private final class IOFiber[A](
onFatalFailure(t)
}

if (ioLocalPropagation) {
localState = IOLocals.getAndClearState()
}

if (isStackTracing) {
// Remove the reference to the fiber monitor handle
objectState.pop().asInstanceOf[WeakBag.Handle].deregister()
Expand Down
Loading