Skip to content

Commit

Permalink
Merge pull request #11 from penguineer/apply-rate-limit
Browse files Browse the repository at this point in the history
Implement rate-limiting for OpenAI API calls
  • Loading branch information
penguineer authored Nov 5, 2024
2 parents e1a94c1 + 6022391 commit 3b5fb43
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ public static RateLimitException fromHttpResponse(HttpResponse response, Consume
String errorMessage = response.getBodyAsString().block();
String retryAfterHeader = response.getHeaders().getValue(HttpHeaderName.RETRY_AFTER);

// try header x-ratelimit-timeremaining
if (Objects.isNull(retryAfterHeader))
retryAfterHeader = response.getHeaders().getValue(HttpHeaderName.fromString("x-ratelimit-timeremaining"));

if (Objects.isNull(retryAfterHeader))
return new RateLimitException(errorMessage);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,16 @@ public class ChatRequestHandler implements ChannelAwareMessageListener {
private final ObjectMapper objectMapper;
private final AIChatService aiChatService;
private final RabbitTemplate rabbitTemplate;
private final RateLimitGate rateLimitGate;

public ChatRequestHandler(ObjectMapper objectMapper,
AIChatService aiChatService,
RabbitTemplate rabbitTemplate) {
RabbitTemplate rabbitTemplate,
RateLimitGate rateLimitGate) {
this.objectMapper = objectMapper;
this.aiChatService = aiChatService;
this.rabbitTemplate = rabbitTemplate;
this.rateLimitGate = rateLimitGate;
}


Expand Down Expand Up @@ -81,7 +84,8 @@ public void onMessage(Message message, Channel channel) {
logger.info("Reply-to header: {}", replyTo);


ChatResponse result = aiChatService.handleChatRequest(chatRequest);
ChatResponse result = rateLimitGate.callWithRateLimit(
() -> aiChatService.handleChatRequest(chatRequest));

// Convert ChatResponse to JSON
String jsonResponse = serializeChatResponse(result);
Expand All @@ -95,6 +99,11 @@ public void onMessage(Message message, Channel channel) {

// Acknowledge the message
channel.basicAck(deliveryTag, false);
} catch (InterruptedException e) {
logger.warn("Interrupted while waiting for rate limit, current message will not be acked and remains in the queue.");

// restore the interrupt flag
Thread.currentThread().interrupt();
} catch (Exception e) {
logger.info("Error on chat request", e);
Optional<String> json = serializeChatError(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.DependsOn;

@Configuration
@EnableRabbit
Expand All @@ -22,6 +23,7 @@ public Queue chatRequestsQueue() {
}

@Bean
@DependsOn("rateLimitGate")
public SimpleMessageListenerContainer chatRequestsContainer(ConnectionFactory connectionFactory,
ChatRequestHandler handler) {
SimpleMessageListenerContainer container = new SimpleMessageListenerContainer();
Expand Down
123 changes: 123 additions & 0 deletions src/main/java/com/penguineering/hareairis/rmq/RateLimitGate.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package com.penguineering.hareairis.rmq;

import com.penguineering.hareairis.ai.RateLimitException;
import jakarta.annotation.PreDestroy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

import java.time.Duration;
import java.time.Instant;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

/**
* Rate limit gate to protect against rate limiting.
*
* <p>Protects against rate limiting by waiting for the next available time before executing a protected call.</p>
*/
@Component
public class RateLimitGate {
private static final Logger logger = LoggerFactory.getLogger(RateLimitGate.class);

final AtomicReference<Instant> nextAvailableTime = new AtomicReference<>(Instant.now());
final AtomicReference<Thread> waitingThread = new AtomicReference<>(null);
private final Lock threadLock = new ReentrantLock();
private final Condition threadCondition = threadLock.newCondition();

/**
* Calls the protected call with rate limiting.
*
* @param protectedCall The protected call to execute.
* @param <T> The type of the result.
* @return The result of the protected call.
* @throws Exception If the protected call throws an exception.
* @throws RateLimitException If the protected call throws a rate limit exception.
* @throws InterruptedException If the waiting thread is interrupted.
*/
public <T> T callWithRateLimit(Callable<T> protectedCall) throws Exception {
T result = null;

do
try {
result = waitAndExecute(protectedCall);
} catch (RateLimitException e) {
if (e.getRetryAfter().isEmpty())
throw e;

registerRateLimitException(e);
}
while (Objects.isNull(result));

return result;
}

/**
* Waits for the next available time and executes the protected call.
*
* @param protectedCall The protected call to execute.
* @param <T> The type of the result.
* @return The result of the protected call.
* @throws Exception If the protected call throws an exception.
* @throws InterruptedException If the waiting thread is interrupted.
*/
public <T> T waitAndExecute(Callable<T> protectedCall) throws Exception {
threadLock.lock();
try {
while (!waitingThread.compareAndSet(null, Thread.currentThread())) {
threadCondition.await();
}

waitingThread.set(Thread.currentThread());

do {
Duration waitTime = Duration.between(Instant.now(), nextAvailableTime.get());
if (waitTime.isNegative())
break;

logger.warn("Rate limit exceeded, waiting for {} seconds...", waitTime.getSeconds());
Thread.sleep(waitTime);
} while (Instant.now().isBefore(nextAvailableTime.get()));

return protectedCall.call();
} finally {
waitingThread.set(null);
threadCondition.signalAll();
threadLock.unlock();
}
}

/**
* Registers a rate-limit exception.
*
* @param e The rate-limit exception to register.
*/
public void registerRateLimitException(RateLimitException e) {
e.getRetryAfter().ifPresentOrElse(
retryAfter -> nextAvailableTime.updateAndGet(
current -> {
Instant newTargetTime = retryAfter.isAfter(current) ? retryAfter : current;
logger.info("Rate limit registered, next available time set to {}", newTargetTime);
return newTargetTime;
}),
() -> logger.warn("Tried to register a rate limit exception without a retry-after time, ignored."));
}

/**
* Interrupts the waiting thread if it is active.
*/
@PreDestroy
public void interruptWaitingThread() {
Optional.of(waitingThread)
.map(AtomicReference::get)
.ifPresent(thread -> {
logger.warn("Waiting thread is active, but application is shutting down. Interrupting...");
thread.interrupt();
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package com.penguineering.hareairis.rmq;

import com.penguineering.hareairis.ai.RateLimitException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.Callable;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;

@ExtendWith(MockitoExtension.class)
class RateLimitGateTest {
private RateLimitGate rateLimitGate;

@Mock
private Callable<String> protectedCall;

@BeforeEach
void setUp() {
rateLimitGate = new RateLimitGate();
}

@Test
void testCallWithRateLimit_Success() throws Exception {
when(protectedCall.call()).thenReturn("Success");

String result = rateLimitGate.callWithRateLimit(protectedCall);

assertEquals("Success", result);
verify(protectedCall, times(1)).call();
}

@Test
void testCallWithRateLimit_RateLimitException() throws Exception {
RateLimitException rateLimitException = new RateLimitException("Rate limit exceeded", Duration.ofSeconds(1));
when(protectedCall.call()).thenThrow(rateLimitException).thenReturn("Success");

String result = rateLimitGate.callWithRateLimit(protectedCall);

assertEquals("Success", result);
verify(protectedCall, times(2)).call();
}

@Test
void testCallWithRateLimit_InterruptedException() throws Exception {
when(protectedCall.call()).thenThrow(new InterruptedException("Interrupted"));

assertThrows(InterruptedException.class, () -> rateLimitGate.callWithRateLimit(protectedCall));
verify(protectedCall, times(1)).call();
}

@Test
void testRegisterRateLimitException() {
RateLimitException rateLimitException = new RateLimitException("Rate limit exceeded", Duration.ofSeconds(1));
rateLimitGate.registerRateLimitException(rateLimitException);

assertTrue(rateLimitGate.nextAvailableTime.get().isAfter(Instant.now()));
}

@Test
void testInterruptWaitingThread() {
Thread mockThread = mock(Thread.class);
rateLimitGate.waitingThread.set(mockThread);

rateLimitGate.interruptWaitingThread();

verify(mockThread, times(1)).interrupt();
}
}

0 comments on commit 3b5fb43

Please sign in to comment.