Skip to content

Commit

Permalink
Amazon Bedrock converse API module supports custom BedrockRuntimeClie…
Browse files Browse the repository at this point in the history
…nt and BedrockRuntimeAsyncClient.
  • Loading branch information
wmz7year committed Jun 8, 2024
1 parent 89f9c89 commit 40930c7
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 115 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,14 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Sinks;
import reactor.core.publisher.Sinks.EmitFailureHandler;
import reactor.core.publisher.Sinks.EmitResult;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
import software.amazon.awssdk.core.document.Document;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
Expand Down Expand Up @@ -69,118 +65,29 @@ public class BedrockConverseApi {

private static final Logger logger = LoggerFactory.getLogger(BedrockConverseApi.class);

private final Region region;
private final BedrockRuntimeClient bedrockRuntimeClient;

private final BedrockRuntimeClient client;

private final BedrockRuntimeAsyncClient clientStreaming;
private final BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient;

private final RetryTemplate retryTemplate;

/**
* Create a new BedrockConverseApi instance using default credentials provider.
*
* @param region The AWS region to use.
*/
public BedrockConverseApi(String region) {
this(ProfileCredentialsProvider.builder().build(), region, Duration.ofMinutes(5));
}

/**
* Create a new BedrockConverseApi instance using default credentials provider.
*
* @param region The AWS region to use.
* @param timeout The timeout to use.
*/
public BedrockConverseApi(String region, Duration timeout) {
this(ProfileCredentialsProvider.builder().build(), region, timeout);
}

/**
* Create a new BedrockConverseApi instance using the provided credentials provider,
* region.
*
* @param credentialsProvider The credentials provider to connect to AWS.
* @param region The AWS region to use.
*/
public BedrockConverseApi(AwsCredentialsProvider credentialsProvider, String region) {
this(credentialsProvider, region, Duration.ofMinutes(5));
}

/**
* Create a new BedrockConverseApi instance using the provided credentials provider,
* region.
* Create a new BedrockConverseApi instance using the provided AWS Bedrock clients and the RetryTemplate.
*
* @param credentialsProvider The credentials provider to connect to AWS.
* @param region The AWS region to use.
* @param timeout Configure the amount of time to allow the client to complete the
* execution of an API call. This timeout covers the entire client execution except
* for marshalling. This includes request handler execution, all HTTP requests
* including retries, unmarshalling, etc. This value should always be positive, if
* present.
*/
public BedrockConverseApi(AwsCredentialsProvider credentialsProvider, String region, Duration timeout) {
this(credentialsProvider, Region.of(region), timeout);
}

/**
* Create a new BedrockConverseApi instance using the provided credentials provider,
* region.
*
* @param credentialsProvider The credentials provider to connect to AWS.
* @param region The AWS region to use.
* @param timeout Configure the amount of time to allow the client to complete the
* execution of an API call. This timeout covers the entire client execution except
* for marshalling. This includes request handler execution, all HTTP requests
* including retries, unmarshalling, etc. This value should always be positive, if
* present.
*/
public BedrockConverseApi(AwsCredentialsProvider credentialsProvider, Region region, Duration timeout) {
this(credentialsProvider, region, timeout, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}

/**
* Create a new BedrockConverseApi instance using the provided credentials provider,
* region
*
* @param credentialsProvider The credentials provider to connect to AWS.
* @param region The AWS region to use.
* @param timeout Configure the amount of time to allow the client to complete the
* execution of an API call. This timeout covers the entire client execution except
* for marshalling. This includes request handler execution, all HTTP requests
* including retries, unmarshalling, etc. This value should always be positive, if
* present.
* @param bedrockRuntimeClient The AWS BedrockRuntimeClient instance.
* @param bedrockRuntimeAsyncClient The AWS BedrockRuntimeAsyncClient instance.
* @param retryTemplate The retry template used to retry the Amazon Bedrock Converse
* API calls.
*/
public BedrockConverseApi(AwsCredentialsProvider credentialsProvider, Region region, Duration timeout,
public BedrockConverseApi(BedrockRuntimeClient bedrockRuntimeClient, BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient,
RetryTemplate retryTemplate) {
Assert.notNull(credentialsProvider, "Credentials provider must not be null");
Assert.notNull(region, "Region must not be empty");
Assert.notNull(timeout, "Timeout must not be null");
Assert.notNull(bedrockRuntimeClient, "bedrockRuntimeClient must not be null");
Assert.notNull(bedrockRuntimeAsyncClient, "bedrockRuntimeAsyncClient must not be null");
Assert.notNull(retryTemplate, "RetryTemplate must not be null");

this.region = region;
this.bedrockRuntimeClient = bedrockRuntimeClient;
this.bedrockRuntimeAsyncClient = bedrockRuntimeAsyncClient;
this.retryTemplate = retryTemplate;

this.client = BedrockRuntimeClient.builder()
.region(this.region)
.credentialsProvider(credentialsProvider)
.overrideConfiguration(c -> c.apiCallTimeout(timeout))
.build();

this.clientStreaming = BedrockRuntimeAsyncClient.builder()
.region(this.region)
.credentialsProvider(credentialsProvider)
.overrideConfiguration(c -> c.apiCallTimeout(timeout))
.build();
}

/**
* @return The AWS region.
*/
public Region getRegion() {
return this.region;
}

/**
Expand Down Expand Up @@ -215,7 +122,7 @@ public ConverseResponse converse(ConverseRequest converseRequest) {
Assert.notNull(converseRequest, "'converseRequest' must not be null");

return this.retryTemplate.execute(ctx -> {
return client.converse(converseRequest);
return bedrockRuntimeClient.converse(converseRequest);
});
}

Expand Down Expand Up @@ -280,7 +187,7 @@ public Flux<ConverseStreamOutput> converseStream(ConverseStreamRequest converseS
})
.build();

clientStreaming.converseStream(converseStreamRequest, responseHandler);
bedrockRuntimeAsyncClient.converseStream(converseStreamRequest, responseHandler);

return eventSink.asFlux();
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.regions.providers.AwsRegionProvider;
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;

import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
Expand Down Expand Up @@ -60,6 +63,32 @@ public AwsRegionProvider regionProvider(BedrockAwsConnectionProperties propertie
return DefaultAwsRegionProviderChain.builder().build();
}

@Bean
@ConditionalOnMissingBean
@ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class })
public BedrockRuntimeClient bedrockRuntimeClient(AwsCredentialsProvider credentialsProvider,
AwsRegionProvider regionProvider, BedrockAwsConnectionProperties properties) {

return BedrockRuntimeClient.builder()
.region(regionProvider.getRegion())
.credentialsProvider(credentialsProvider)
.overrideConfiguration(c -> c.apiCallTimeout(properties.getTimeout()))
.build();
}

@Bean
@ConditionalOnMissingBean
@ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class })
public BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient(AwsCredentialsProvider credentialsProvider,
AwsRegionProvider regionProvider, BedrockAwsConnectionProperties properties) {

return BedrockRuntimeAsyncClient.builder()
.region(regionProvider.getRegion())
.credentialsProvider(credentialsProvider)
.overrideConfiguration(c -> c.apiCallTimeout(properties.getTimeout()))
.build();
}

/**
* @author Wei Jiang
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
import org.springframework.context.annotation.Import;
import org.springframework.retry.support.RetryTemplate;

import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.providers.AwsRegionProvider;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;

Expand All @@ -47,12 +45,10 @@ public class BedrockConverseApiAutoConfiguration {

@Bean
@ConditionalOnMissingBean
@ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class })
public BedrockConverseApi bedrockConverseApi(AwsCredentialsProvider credentialsProvider,
AwsRegionProvider regionProvider, BedrockAwsConnectionProperties awsProperties,
RetryTemplate retryTemplate) {
return new BedrockConverseApi(credentialsProvider, regionProvider.getRegion(), awsProperties.getTimeout(),
retryTemplate);
@ConditionalOnBean({ BedrockRuntimeClient.class, BedrockRuntimeAsyncClient.class })
public BedrockConverseApi bedrockConverseApi(BedrockRuntimeClient bedrockRuntimeClient,
BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, RetryTemplate retryTemplate) {
return new BedrockConverseApi(bedrockRuntimeClient, bedrockRuntimeAsyncClient, retryTemplate);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.regions.providers.AwsRegionProvider;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;

/**
* @author Wei Jiang
Expand Down Expand Up @@ -87,6 +89,39 @@ public void autoConfigureWithCustomAWSCredentialAndRegionProvider() {
});
}

@Test
public void autoConfigureBedrockClients() {
new ApplicationContextRunner()
.withPropertyValues("spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"),
"spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"),
"spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id())
.withConfiguration(AutoConfigurations.of(TestAutoConfiguration.class))
.run((context) -> {
var bedrockRuntimeClient = context.getBean(BedrockRuntimeClient.class);
var bedrockRuntimeAsyncClient = context.getBean(BedrockRuntimeAsyncClient.class);

assertThat(bedrockRuntimeClient).isNotNull();
assertThat(bedrockRuntimeAsyncClient).isNotNull();
});
}

@Test
public void autoConfigureWithCustomBedrockClients() {
new ApplicationContextRunner()
.withPropertyValues("spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"),
"spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"),
"spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id())
.withConfiguration(AutoConfigurations.of(TestAutoConfiguration.class,
CustomBedrockRuntimeClientAutoConfiguration.class))
.run((context) -> {
var bedrockRuntimeClient = context.getBean(BedrockRuntimeClient.class);
var bedrockRuntimeAsyncClient = context.getBean(BedrockRuntimeAsyncClient.class);

assertThat(bedrockRuntimeClient).isNotNull();
assertThat(bedrockRuntimeAsyncClient).isNotNull();
});
}

@EnableConfigurationProperties({ BedrockAwsConnectionProperties.class })
@Import(BedrockAwsConnectionConfiguration.class)
static class TestAutoConfiguration {
Expand Down Expand Up @@ -136,4 +171,29 @@ public Region getRegion() {

}

@AutoConfiguration
static class CustomBedrockRuntimeClientAutoConfiguration {

@Bean
public BedrockRuntimeClient bedrockRuntimeClient(AwsCredentialsProvider credentialsProvider,
AwsRegionProvider regionProvider) {

return BedrockRuntimeClient.builder()
.region(regionProvider.getRegion())
.credentialsProvider(credentialsProvider)
.build();
}

@Bean
public BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient(AwsCredentialsProvider credentialsProvider,
AwsRegionProvider regionProvider) {

return BedrockRuntimeAsyncClient.builder()
.region(regionProvider.getRegion())
.credentialsProvider(credentialsProvider)
.build();
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ public void autoConfigureBedrockConverseApi() {
var bedrockConverseApi = context.getBean(BedrockConverseApi.class);

assertThat(bedrockConverseApi).isNotNull();

assertThat(bedrockConverseApi.getRegion()).isEqualTo(Region.US_EAST_1);
});
}

Expand Down

0 comments on commit 40930c7

Please sign in to comment.