-
Notifications
You must be signed in to change notification settings - Fork 24
/
AnthropicCreateChatCompletionStreamedWithOpenAIAdapter.scala
78 lines (67 loc) · 2.43 KB
/
AnthropicCreateChatCompletionStreamedWithOpenAIAdapter.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
package io.cequence.openaiscala.examples.nonopenai
import akka.NotUsed
import akka.stream.scaladsl.{RestartSource, Sink, Source}
import io.cequence.openaiscala.OpenAIScalaClientException
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
import io.cequence.openaiscala.domain.{NonOpenAIModelId, SystemMessage, UserMessage}
import io.cequence.openaiscala.examples.ExampleBase
import io.cequence.openaiscala.service.StreamedServiceTypes.OpenAIChatCompletionStreamedService
import org.slf4j.LoggerFactory
import scala.concurrent.Future
import scala.concurrent.duration.{DurationDouble, DurationInt}
/**
 * Example: streams a chat completion from Anthropic (Claude 3.5 Sonnet) through the
 * OpenAI-compatible chat-completion adapter, printing each streamed content chunk and
 * restarting the stream with exponential backoff when it fails.
 *
 * Requires `openai-scala-anthropic-client` as a dependency and the `ANTHROPIC_API_KEY`
 * environment variable to be set.
 */
object AnthropicCreateChatCompletionStreamedWithOpenAIAdapter
    extends ExampleBase[OpenAIChatCompletionStreamedService] {

  private val logger = LoggerFactory.getLogger(this.getClass)

  override val service: OpenAIChatCompletionStreamedService =
    ChatCompletionProvider.anthropic()

  private val messages = Seq(
    SystemMessage("You are a helpful assistant."),
    UserMessage("What is the weather like in Norway?")
  )

  /** Number of times the stream factory has been invoked during the current `run`. */
  @volatile var attemptCounter = 0

  /** Maximum number of real streaming attempts before failing with "Max attempts reached". */
  private val maxAttempts = 3

  /**
   * Runs the streamed completion with retries and prints every content chunk to stdout.
   *
   * @return
   *   a future that completes when the stream finishes, or fails with the stream's final error
   *   (an `OpenAIScalaClientException("Max attempts reached")` once the attempt budget is spent)
   */
  override protected def run: Future[_] = {
    // The counter lives on the object; reset it so every run gets the full attempt budget
    // instead of inheriting attempts consumed by a previous run.
    attemptCounter = 0

    // Builds a fresh streaming request; only the first choice's delta content is kept,
    // with chunks that carry no content mapped to "".
    def createSource(): Source[String, NotUsed] =
      service
        .createChatCompletionStreamed(
          messages = messages,
          settings = CreateChatCompletionSettings(
            model = NonOpenAIModelId.claude_3_5_sonnet_20240620
          )
        )
        .map(
          _.choices.headOption.flatMap(_.delta.content).getOrElse("")
        )

    val sourceWithRetry: Source[String, _] = RestartSource.onFailuresWithBackoff(
      minBackoff = 0.5.seconds,
      maxBackoff = 20.seconds,
      randomFactor = 0.2,
      // Derived from `maxAttempts` so both limits stay in sync: the restart after the last real
      // attempt is what lets the factory below emit the explicit "Max attempts reached" failure.
      // NOTE(review): Akka may reset its internal restart count when restarts are spaced more
      // than `minBackoff` apart, so the manual counter is what actually bounds the real
      // attempts. Confirm against the Akka version in use.
      maxRestarts = maxAttempts
    ) { () =>
      attemptCounter += 1
      if (attemptCounter <= maxAttempts) {
        createSource()
      } else {
        Source.failed(new OpenAIScalaClientException("Max attempts reached"))
      }
    }

    sourceWithRetry
      // Log the final outcome of the whole (retried) stream. `onComplete` needs an implicit
      // ExecutionContext, presumably supplied by ExampleBase (not visible here).
      .watchTermination()(
        (
          _,
          done
        ) => {
          done.onComplete {
            case scala.util.Success(_) =>
              logger.debug("Response completed successfully.")
            case scala.util.Failure(ex) =>
              logger.error("Response failed with an exception.", ex)
          }
        }
      )
      .runWith(
        Sink.foreach(print)
      )
  }
}