Skip to content

Commit

Permalink
Fix: watsonx.ai | refresh token issue and implement retry mechanism (#…
Browse files Browse the repository at this point in the history
…735)

* fix: refresh token only when needed and apply retry policy

* fix: remove unused import
  • Loading branch information
PabloSanchi authored May 19, 2024
1 parent 09e122d commit e8f663d
Showing 1 changed file with 15 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
import java.util.function.Consumer;

import com.ibm.cloud.sdk.core.security.IamAuthenticator;
import com.ibm.cloud.sdk.core.security.IamToken;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.retry.annotation.Backoff;
import org.springframework.retry.annotation.Retryable;
import reactor.core.publisher.Flux;

import org.springframework.ai.retry.RetryUtils;
Expand Down Expand Up @@ -50,6 +53,7 @@ public class WatsonxAiApi {
private final String streamEndpoint;
private final String textEndpoint;
private final String projectId;
private IamToken token;

/**
* Create a new chat api.
Expand All @@ -72,6 +76,7 @@ public WatsonxAiApi(
this.textEndpoint = textEndpoint;
this.projectId = projectId;
this.iamAuthenticator = IamAuthenticator.fromConfiguration(Map.of("APIKEY", IAMToken));
this.token = this.iamAuthenticator.requestToken();

Consumer<HttpHeaders> defaultHeaders = headers -> {
headers.setContentType(MediaType.APPLICATION_JSON);
Expand All @@ -88,27 +93,33 @@ public WatsonxAiApi(
.build();
}

@Retryable(retryFor = Exception.class, maxAttempts = 3, backoff = @Backoff(random = true, delay = 1200, maxDelay = 7000, multiplier = 2.5))
public ResponseEntity<WatsonxAiResponse> generate(WatsonxAiRequest watsonxAiRequest) {
Assert.notNull(watsonxAiRequest, WATSONX_REQUEST_CANNOT_BE_NULL);

String bearer = this.iamAuthenticator.requestToken().getAccessToken();
if(this.token.needsRefresh()) {
this.token = this.iamAuthenticator.requestToken();
}

return this.restClient.post()
.uri(this.textEndpoint)
.header(HttpHeaders.AUTHORIZATION, "Bearer " + bearer)
.header(HttpHeaders.AUTHORIZATION, "Bearer " + this.token.getAccessToken())
.body(watsonxAiRequest.withProjectId(projectId))
.retrieve()
.toEntity(WatsonxAiResponse.class);
}

@Retryable(retryFor = Exception.class, maxAttempts = 3, backoff = @Backoff(random = true, delay = 1200, maxDelay = 7000, multiplier = 2.5))
public Flux<WatsonxAiResponse> generateStreaming(WatsonxAiRequest watsonxAiRequest) {
Assert.notNull(watsonxAiRequest, WATSONX_REQUEST_CANNOT_BE_NULL);

String bearer = this.iamAuthenticator.requestToken().getAccessToken();
if(this.token.needsRefresh()) {
this.token = this.iamAuthenticator.requestToken();
}

return this.webClient.post()
.uri(this.streamEndpoint)
.header(HttpHeaders.AUTHORIZATION, "Bearer " + bearer)
.header(HttpHeaders.AUTHORIZATION, "Bearer " + this.token.getAccessToken())
.bodyValue(watsonxAiRequest.withProjectId(this.projectId))
.retrieve()
.bodyToFlux(WatsonxAiResponse.class)
Expand Down

0 comments on commit e8f663d

Please sign in to comment.