diff --git a/src/libraries/DrawAccumulatorLib.sol b/src/libraries/DrawAccumulatorLib.sol index 4f8ef4c..da8d27c 100644 --- a/src/libraries/DrawAccumulatorLib.sol +++ b/src/libraries/DrawAccumulatorLib.sol @@ -46,12 +46,6 @@ library DrawAccumulatorLib { mapping(uint256 drawId => Observation observation) observations; } - /// @notice A pair of uint24s. - struct Pair48 { - uint24 first; - uint24 second; - } - /// @notice Adds balance for the given draw id to the accumulator. /// @param accumulator The accumulator to add to /// @param _amount The amount of balance to add @@ -103,10 +97,7 @@ library DrawAccumulatorLib { return true; } else { - accumulatorObservations[newestDrawId_] = Observation({ - available: SafeCast.toUint96(newestObservation_.available + _amount), - disbursed: newestObservation_.disbursed - }); + accumulatorObservations[newestDrawId_].available = SafeCast.toUint96(newestObservation_.available + _amount); return false; } @@ -161,56 +152,48 @@ library DrawAccumulatorLib { return 0; } - uint24 firstObservationDrawIdOccurringAtOrAfterStart; - if (_startDrawId <= oldestDrawId || ringBufferInfo.cardinality == 1) { - firstObservationDrawIdOccurringAtOrAfterStart = oldestDrawId; - } else { - // The start must be between newest and oldest - uint24 beforeOrAtDrawId; - // binary search - ( - , - beforeOrAtDrawId, - , - firstObservationDrawIdOccurringAtOrAfterStart - ) = binarySearch( - _accumulator.drawRingBuffer, - uint16(oldestIndex), - uint16(newestIndex), - ringBufferInfo.cardinality, - _startDrawId - ); - if (beforeOrAtDrawId == _startDrawId) { - firstObservationDrawIdOccurringAtOrAfterStart = _startDrawId; + // check if the start draw has an observation, otherwise search for the earliest observation after + Observation memory atOrAfterStart = _accumulator.observations[_startDrawId]; + if (atOrAfterStart.available == 0 && atOrAfterStart.disbursed == 0) { + if (_startDrawId <= oldestDrawId || ringBufferInfo.cardinality == 1) { + atOrAfterStart = _accumulator.observations[oldestDrawId]; + } else { + (, uint24 beforeOrAtDrawId, , uint24 afterOrAtDrawId) = binarySearch( + _accumulator.drawRingBuffer, + oldestIndex, + newestIndex, + ringBufferInfo.cardinality, + _startDrawId + ); + if (beforeOrAtDrawId == _startDrawId) { + atOrAfterStart = _accumulator.observations[_startDrawId]; + } else { + atOrAfterStart = _accumulator.observations[afterOrAtDrawId]; + } } } - uint24 lastObservationDrawIdOccurringAtOrBeforeEnd; - if (_endDrawId >= newestDrawId || ringBufferInfo.cardinality == 1) { - // then it must be the end - lastObservationDrawIdOccurringAtOrBeforeEnd = newestDrawId; - } else { - uint24 afterOrAtDrawId; - (, lastObservationDrawIdOccurringAtOrBeforeEnd, ,afterOrAtDrawId) = binarySearch( - _accumulator.drawRingBuffer, - uint16(oldestIndex), - uint16(newestIndex), - ringBufferInfo.cardinality, - _endDrawId - ); - if (afterOrAtDrawId == _endDrawId) { - lastObservationDrawIdOccurringAtOrBeforeEnd = _endDrawId; + // check if the end draw has an observation, otherwise search for the latest observation before + Observation memory atOrBeforeEnd = _accumulator.observations[_endDrawId]; + if (atOrBeforeEnd.available == 0 && atOrBeforeEnd.disbursed == 0) { + if (_endDrawId >= newestDrawId || ringBufferInfo.cardinality == 1) { + atOrBeforeEnd = _accumulator.observations[newestDrawId]; + } else { + (, uint24 beforeOrAtDrawId, , uint24 afterOrAtDrawId) = binarySearch( + _accumulator.drawRingBuffer, + oldestIndex, + newestIndex, + ringBufferInfo.cardinality, + _endDrawId + ); + if (afterOrAtDrawId == _endDrawId) { + atOrBeforeEnd = _accumulator.observations[_endDrawId]; + } else { + atOrBeforeEnd = _accumulator.observations[beforeOrAtDrawId]; + } } } - Observation memory atOrAfterStart = _accumulator.observations[ - firstObservationDrawIdOccurringAtOrAfterStart - ]; - - Observation memory atOrBeforeEnd = _accumulator.observations[ - lastObservationDrawIdOccurringAtOrBeforeEnd - ]; - return atOrBeforeEnd.available + atOrBeforeEnd.disbursed - atOrAfterStart.disbursed; }