-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #11 from penguineer/apply-rate-limit
Implement rate-limiting for OpenAI API calls
- Loading branch information
Showing
5 changed files
with
215 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
123 changes: 123 additions & 0 deletions
123
src/main/java/com/penguineering/hareairis/rmq/RateLimitGate.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
package com.penguineering.hareairis.rmq; | ||
|
||
import com.penguineering.hareairis.ai.RateLimitException; | ||
import jakarta.annotation.PreDestroy; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
import org.springframework.stereotype.Component; | ||
|
||
import java.time.Duration; | ||
import java.time.Instant; | ||
import java.util.Objects; | ||
import java.util.Optional; | ||
import java.util.concurrent.Callable; | ||
import java.util.concurrent.atomic.AtomicReference; | ||
import java.util.concurrent.locks.Condition; | ||
import java.util.concurrent.locks.Lock; | ||
import java.util.concurrent.locks.ReentrantLock; | ||
|
||
/** | ||
* Rate limit gate to protect against rate limiting. | ||
* | ||
* <p>Protects against rate limiting by waiting for the next available time before executing a protected call.</p> | ||
*/ | ||
@Component | ||
public class RateLimitGate { | ||
private static final Logger logger = LoggerFactory.getLogger(RateLimitGate.class); | ||
|
||
final AtomicReference<Instant> nextAvailableTime = new AtomicReference<>(Instant.now()); | ||
final AtomicReference<Thread> waitingThread = new AtomicReference<>(null); | ||
private final Lock threadLock = new ReentrantLock(); | ||
private final Condition threadCondition = threadLock.newCondition(); | ||
|
||
/** | ||
* Calls the protected call with rate limiting. | ||
* | ||
* @param protectedCall The protected call to execute. | ||
* @param <T> The type of the result. | ||
* @return The result of the protected call. | ||
* @throws Exception If the protected call throws an exception. | ||
* @throws RateLimitException If the protected call throws a rate limit exception. | ||
* @throws InterruptedException If the waiting thread is interrupted. | ||
*/ | ||
public <T> T callWithRateLimit(Callable<T> protectedCall) throws Exception { | ||
T result = null; | ||
|
||
do | ||
try { | ||
result = waitAndExecute(protectedCall); | ||
} catch (RateLimitException e) { | ||
if (e.getRetryAfter().isEmpty()) | ||
throw e; | ||
|
||
registerRateLimitException(e); | ||
} | ||
while (Objects.isNull(result)); | ||
|
||
return result; | ||
} | ||
|
||
/** | ||
* Waits for the next available time and executes the protected call. | ||
* | ||
* @param protectedCall The protected call to execute. | ||
* @param <T> The type of the result. | ||
* @return The result of the protected call. | ||
* @throws Exception If the protected call throws an exception. | ||
* @throws InterruptedException If the waiting thread is interrupted. | ||
*/ | ||
public <T> T waitAndExecute(Callable<T> protectedCall) throws Exception { | ||
threadLock.lock(); | ||
try { | ||
while (!waitingThread.compareAndSet(null, Thread.currentThread())) { | ||
threadCondition.await(); | ||
} | ||
|
||
waitingThread.set(Thread.currentThread()); | ||
|
||
do { | ||
Duration waitTime = Duration.between(Instant.now(), nextAvailableTime.get()); | ||
if (waitTime.isNegative()) | ||
break; | ||
|
||
logger.warn("Rate limit exceeded, waiting for {} seconds...", waitTime.getSeconds()); | ||
Thread.sleep(waitTime); | ||
} while (Instant.now().isBefore(nextAvailableTime.get())); | ||
|
||
return protectedCall.call(); | ||
} finally { | ||
waitingThread.set(null); | ||
threadCondition.signalAll(); | ||
threadLock.unlock(); | ||
} | ||
} | ||
|
||
/** | ||
* Registers a rate-limit exception. | ||
* | ||
* @param e The rate-limit exception to register. | ||
*/ | ||
public void registerRateLimitException(RateLimitException e) { | ||
e.getRetryAfter().ifPresentOrElse( | ||
retryAfter -> nextAvailableTime.updateAndGet( | ||
current -> { | ||
Instant newTargetTime = retryAfter.isAfter(current) ? retryAfter : current; | ||
logger.info("Rate limit registered, next available time set to {}", newTargetTime); | ||
return newTargetTime; | ||
}), | ||
() -> logger.warn("Tried to register a rate limit exception without a retry-after time, ignored.")); | ||
} | ||
|
||
/** | ||
* Interrupts the waiting thread if it is active. | ||
*/ | ||
@PreDestroy | ||
public void interruptWaitingThread() { | ||
Optional.of(waitingThread) | ||
.map(AtomicReference::get) | ||
.ifPresent(thread -> { | ||
logger.warn("Waiting thread is active, but application is shutting down. Interrupting..."); | ||
thread.interrupt(); | ||
}); | ||
} | ||
} |
75 changes: 75 additions & 0 deletions
75
src/test/java/com/penguineering/hareairis/rmq/RateLimitGateTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
package com.penguineering.hareairis.rmq; | ||
|
||
import com.penguineering.hareairis.ai.RateLimitException; | ||
import org.junit.jupiter.api.BeforeEach; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.api.extension.ExtendWith; | ||
import org.mockito.Mock; | ||
import org.mockito.junit.jupiter.MockitoExtension; | ||
|
||
import java.time.Duration; | ||
import java.time.Instant; | ||
import java.util.concurrent.Callable; | ||
|
||
import static org.junit.jupiter.api.Assertions.*; | ||
import static org.mockito.Mockito.*; | ||
|
||
@ExtendWith(MockitoExtension.class) | ||
class RateLimitGateTest { | ||
private RateLimitGate rateLimitGate; | ||
|
||
@Mock | ||
private Callable<String> protectedCall; | ||
|
||
@BeforeEach | ||
void setUp() { | ||
rateLimitGate = new RateLimitGate(); | ||
} | ||
|
||
@Test | ||
void testCallWithRateLimit_Success() throws Exception { | ||
when(protectedCall.call()).thenReturn("Success"); | ||
|
||
String result = rateLimitGate.callWithRateLimit(protectedCall); | ||
|
||
assertEquals("Success", result); | ||
verify(protectedCall, times(1)).call(); | ||
} | ||
|
||
@Test | ||
void testCallWithRateLimit_RateLimitException() throws Exception { | ||
RateLimitException rateLimitException = new RateLimitException("Rate limit exceeded", Duration.ofSeconds(1)); | ||
when(protectedCall.call()).thenThrow(rateLimitException).thenReturn("Success"); | ||
|
||
String result = rateLimitGate.callWithRateLimit(protectedCall); | ||
|
||
assertEquals("Success", result); | ||
verify(protectedCall, times(2)).call(); | ||
} | ||
|
||
@Test | ||
void testCallWithRateLimit_InterruptedException() throws Exception { | ||
when(protectedCall.call()).thenThrow(new InterruptedException("Interrupted")); | ||
|
||
assertThrows(InterruptedException.class, () -> rateLimitGate.callWithRateLimit(protectedCall)); | ||
verify(protectedCall, times(1)).call(); | ||
} | ||
|
||
@Test | ||
void testRegisterRateLimitException() { | ||
RateLimitException rateLimitException = new RateLimitException("Rate limit exceeded", Duration.ofSeconds(1)); | ||
rateLimitGate.registerRateLimitException(rateLimitException); | ||
|
||
assertTrue(rateLimitGate.nextAvailableTime.get().isAfter(Instant.now())); | ||
} | ||
|
||
@Test | ||
void testInterruptWaitingThread() { | ||
Thread mockThread = mock(Thread.class); | ||
rateLimitGate.waitingThread.set(mockThread); | ||
|
||
rateLimitGate.interruptWaitingThread(); | ||
|
||
verify(mockThread, times(1)).interrupt(); | ||
} | ||
} |