Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 70 additions & 40 deletions cpc/cpc_compressed_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package cpc

import (
"errors"
"fmt"
"math/bits"

Expand Down Expand Up @@ -271,7 +272,10 @@ func (c *CpcCompressedState) compressHybridFlavor(src *CpcSketch) error {
return fmt.Errorf("compressHybridFlavor: invariant violation (%d + %d != %d)",
numPairsFromArray, srcNumPairs, srcNumCoupons)
}
allPairs := trickyGetPairsFromWindow(srcSlidingWindow, srcK, numPairsFromArray, srcNumPairs)
allPairs, err := trickyGetPairsFromWindow(srcSlidingWindow, srcK, numPairsFromArray, srcNumPairs)
if err != nil {
return err
}
mergePairs(srcPairArr, 0, srcNumPairs, allPairs, srcNumPairs, numPairsFromArray, allPairs, 0)
return compressTheSurprisingValues(c, src, allPairs, int(srcNumCoupons))
}
Expand Down Expand Up @@ -428,7 +432,10 @@ func (c *CpcCompressedState) compressSlidingFlavor(src *CpcSketch) error {
}

// Apply a transformation to the column indices.
pseudoPhase := determinePseudoPhase(src.lgK, int64(src.numCoupons))
pseudoPhase, err := determinePseudoPhase(src.lgK, int64(src.numCoupons))
if err != nil {
return err
}
if pseudoPhase >= 16 {
return fmt.Errorf("compressSlidingFlavor: pseudoPhase (%d) >= 16", pseudoPhase)
}
Expand Down Expand Up @@ -492,7 +499,10 @@ func (c *CpcCompressedState) uncompressSlidingFlavor(src *CpcSketch) error {
}

// Determine pseudoPhase.
pseudoPhase := determinePseudoPhase(srcLgK, int64(c.NumCoupons))
pseudoPhase, err := determinePseudoPhase(srcLgK, int64(c.NumCoupons))
if err != nil {
return err
}
if pseudoPhase >= 16 {
return fmt.Errorf("uncompressSlidingFlavor: pseudoPhase %d out of range", pseudoPhase)
}
Expand Down Expand Up @@ -608,7 +618,7 @@ func importFromMemory(bytes []byte) (*CpcCompressedState, error) {
state.CwStream = getWStream(bytes)
state.CsvStream = getSvStream(bytes)
default:
panic("not implemented")
return nil, fmt.Errorf("unknown format: %d", format)
}
return state, nil
}
Expand Down Expand Up @@ -675,14 +685,23 @@ func compressTheSurprisingValues(target *CpcCompressedState, source *CpcSketch,
// Compute srcK = 1 << source.lgK.
srcK := 1 << source.lgK
// Determine the number of base bits using a Golomb code decision.
numBaseBits := golombChooseNumberOfBaseBits(srcK+numPairs, numPairs)
numBaseBits, err := golombChooseNumberOfBaseBits(srcK+numPairs, numPairs)
if err != nil {
return err
}
// Compute an upper-bound length for the compressed pairs buffer.
pairBufLen := safeLengthForCompressedPairBuf(srcK, numPairs, numBaseBits)
pairBufLen, err := safeLengthForCompressedPairBuf(srcK, numPairs, numBaseBits)
if err != nil {
return err
}
// Allocate the buffer for compression.
pairBuf := make([]int, pairBufLen)
// lowLevelCompressPairs compresses 'pairs' using the chosen base bits into pairBuf.
// It returns the number of ints that represent the compressed data.
csvLength := lowLevelCompressPairs(pairs, numPairs, numBaseBits, pairBuf)
csvLength, err := lowLevelCompressPairs(pairs, numPairs, numBaseBits, pairBuf)
if err != nil {
return err
}
target.CsvLengthInts = csvLength
target.CsvStream = pairBuf
return nil
Expand All @@ -696,32 +715,35 @@ func uncompressTheSurprisingValues(source *CpcCompressedState) ([]int, error) {
}
pairs := make([]int, numPairs)
// Determine the number of base bits using the Golomb code decision.
numBaseBits := golombChooseNumberOfBaseBits(srcK+numPairs, numPairs)
numBaseBits, err := golombChooseNumberOfBaseBits(srcK+numPairs, numPairs)
if err != nil {
return nil, err
}
// lowLevelUncompressPairs fills the 'pairs' slice using the compressed CSV stream.
if err := lowLevelUncompressPairs(pairs, numPairs, numBaseBits, source.CsvStream, source.CsvLengthInts); err != nil {
return nil, err
}
return pairs, nil
}

func golombChooseNumberOfBaseBits(k, count int) int {
func golombChooseNumberOfBaseBits(k, count int) (int, error) {
if k < 1 || count < 1 {
panic("golombChooseNumberOfBaseBits: k and count must be >= 1")
return 0, errors.New("golombChooseNumberOfBaseBits: k and count must be >= 1")
}
quotient := (k - count) / count
if quotient == 0 {
return 0
return 0, nil
}
return floorLog2(uint64(quotient))
return floorLog2(uint64(quotient)), nil
}

// floorLog2 returns the zero-based position of the most significant set bit
// of x, which equals floor(log2(x)) for x > 0. For x == 0 there is no set bit
// and the result is -1.
func floorLog2(x uint64) int {
	// A 64-bit value with z leading zeros has its top set bit at index 63-z;
	// for x == 0, z is 64 and the result is -1.
	return 63 - bits.LeadingZeros64(x)
}

func safeLengthForCompressedPairBuf(k, numPairs, numBaseBits int) int {
func safeLengthForCompressedPairBuf(k, numPairs, numBaseBits int) (int, error) {
if numPairs <= 0 {
panic("safeLengthForCompressedPairBuf: numPairs must be > 0")
return 0, errors.New("safeLengthForCompressedPairBuf: numPairs must be > 0")
}
// Compute ybits = (numPairs * (1 + numBaseBits)) + (k >>> numBaseBits)
ybits := int64(numPairs)*(1+int64(numBaseBits)) + (int64(k) >> uint(numBaseBits))
Expand All @@ -736,9 +758,9 @@ func safeLengthForCompressedPairBuf(k, numPairs, numBaseBits int) int {
words := divideBy32RoundingUp(totalBits)
// Ensure the number of words fits in a 31-bit int.
if words >= (1 << 31) {
panic("safeLengthForCompressedPairBuf: words too large")
return 0, errors.New("safeLengthForCompressedPairBuf: words too large")
}
return int(words)
return int(words), nil
}

func divideBy32RoundingUp(x int64) int64 {
Expand All @@ -749,7 +771,7 @@ func divideBy32RoundingUp(x int64) int64 {
return tmp + 1
}

func lowLevelCompressPairs(pairArray []int, numPairsToEncode, numBaseBits int, compressedWords []int) int {
func lowLevelCompressPairs(pairArray []int, numPairsToEncode, numBaseBits int, compressedWords []int) (int, error) {
nextWordIndex := 0
var bitBuf uint64 = 0
bufBits := 0
Expand All @@ -773,8 +795,7 @@ func lowLevelCompressPairs(pairArray []int, numPairsToEncode, numBaseBits int, c
predictedColIndex = 0
}
if rowIndex < predictedRowIndex || colIndex < predictedColIndex {
panic(fmt.Sprintf("lowLevelCompressPairs: assertion failed: rowIndex=%d, predictedRowIndex=%d, colIndex=%d, predictedColIndex=%d",
rowIndex, predictedRowIndex, colIndex, predictedColIndex))
return 0, fmt.Errorf("lowLevelCompressPairs: assertion failed: rowIndex=%d, predictedRowIndex=%d, colIndex=%d, predictedColIndex=%d", rowIndex, predictedRowIndex, colIndex, predictedColIndex)
}

// yDelta is the difference in row indices.
Expand Down Expand Up @@ -846,7 +867,7 @@ func lowLevelCompressPairs(pairArray []int, numPairsToEncode, numBaseBits int, c
compressedWords[nextWordIndex] = int(bitBuf & 0xFFFFFFFF)
nextWordIndex++
}
return nextWordIndex
return nextWordIndex, nil
}

func lowLevelUncompressPairs(pairArray []int, numPairsToDecode, numBaseBits int, compressedWords []int, numCompressedWords int) error {
Expand Down Expand Up @@ -889,7 +910,10 @@ func lowLevelUncompressPairs(pairArray []int, numPairsToDecode, numBaseBits int,
ptrArr[NextWordIdx] = int64(nextWordIndex)
ptrArr[BitBuf] = int64(bitBuf)
ptrArr[BufBits] = int64(bufBits)
golombHi := readUnary(compressedWords, ptrArr)
golombHi, err := readUnary(compressedWords, ptrArr)
if err != nil {
return err
}
// Retrieve updated values.
nextWordIndex = int(ptrArr[NextWordIdx])
bitBuf = uint64(ptrArr[BitBuf])
Expand Down Expand Up @@ -931,7 +955,7 @@ func lowLevelUncompressPairs(pairArray []int, numPairsToDecode, numBaseBits int,
return nil
}

func readUnary(compressedWords []int, ptrArr []int64) int64 {
func readUnary(compressedWords []int, ptrArr []int64) (int64, error) {
nextWordIndex := int(ptrArr[NextWordIdx])
bitBuf := uint64(ptrArr[BitBuf])
bufBits := int(ptrArr[BufBits])
Expand All @@ -944,7 +968,7 @@ func readUnary(compressedWords []int, ptrArr []int64) int64 {
// Ensure we have at least 8 bits in the bit buffer.
if bufBits < 8 {
if nextWordIndex >= len(compressedWords) {
panic("readUnary: insufficient compressedWords data")
return 0, errors.New("readUnary: insufficient compressedWords data")
}
bitBuf |= (uint64(compressedWords[nextWordIndex]) & 0xFFFFFFFF) << uint(bufBits)
nextWordIndex++
Expand Down Expand Up @@ -975,7 +999,7 @@ func readUnary(compressedWords []int, ptrArr []int64) int64 {
ptrArr[BitBuf] = int64(bitBuf)
ptrArr[BufBits] = int64(bufBits)

return subTotal + int64(trailingZeros)
return subTotal + int64(trailingZeros), nil
}

func writeUnary(compressedWords []int, ptrArr []int64, theValue int) {
Expand Down Expand Up @@ -1011,7 +1035,7 @@ func writeUnary(compressedWords []int, ptrArr []int64, theValue int) {
ptrArr[BufBits] = int64(bufBits)
}

func trickyGetPairsFromWindow(window []byte, k, numPairsToGet, emptySpace int) []int {
func trickyGetPairsFromWindow(window []byte, k, numPairsToGet, emptySpace int) ([]int, error) {
outputLength := emptySpace + numPairsToGet
pairs := make([]int, outputLength)
pairIndex := emptySpace
Expand All @@ -1031,10 +1055,10 @@ func trickyGetPairsFromWindow(window []byte, k, numPairsToGet, emptySpace int) [
}

if pairIndex != outputLength {
panic(fmt.Sprintf("trickyGetPairsFromWindow: pairIndex (%d) != outputLength (%d)", pairIndex, outputLength))
return nil, fmt.Errorf("trickyGetPairsFromWindow: pairIndex (%d) != outputLength (%d)", pairIndex, outputLength)
}

return pairs
return pairs, nil
}

func (c *CpcCompressedState) compressTheWindow(src *CpcSketch) error {
Expand All @@ -1045,7 +1069,10 @@ func (c *CpcCompressedState) compressTheWindow(src *CpcSketch) error {
windowBufLen := safeLengthForCompressedWindowBuf(int64(srcK))
windowBuf := make([]int, windowBufLen)
// Determine the pseudo-phase using srcLgK and the number of coupons.
pseudoPhase := determinePseudoPhase(srcLgK, int64(src.numCoupons))
pseudoPhase, err := determinePseudoPhase(srcLgK, int64(src.numCoupons))
if err != nil {
return err
}
// Compress the sliding window bytes.
// lowLevelCompressBytes is called here as returning a single value (cwLengthInts); no error is returned.
cwLengthInts := lowLevelCompressBytes(src.slidingWindow, srcK, encodingTablesForHighEntropyByte[pseudoPhase], windowBuf)
Expand All @@ -1069,7 +1096,10 @@ func uncompressTheWindow(target *CpcSketch, source *CpcCompressedState) error {
target.slidingWindow = window

// Determine the pseudo-phase using srcLgK and source.NumCoupons.
pseudoPhase := determinePseudoPhase(srcLgK, int64(source.NumCoupons))
pseudoPhase, err := determinePseudoPhase(srcLgK, int64(source.NumCoupons))
if err != nil {
return err
}
// Ensure that source.CwStream is not nil.
if source.CwStream == nil {
return fmt.Errorf("uncompressTheWindow: source.CwStream is nil")
Expand All @@ -1091,37 +1121,37 @@ func safeLengthForCompressedWindowBuf(k int64) int {
return int(divideBy32RoundingUp(totalBits))
}

func determinePseudoPhase(lgK int, numCoupons int64) int {
func determinePseudoPhase(lgK int, numCoupons int64) (int, error) {
k := int64(1) << uint(lgK)
c := numCoupons
// Midrange logic.
if (1000 * c) < (2375 * k) {
if (4 * c) < (3 * k) {
return 16 + 0
return 16 + 0, nil
} else if (10 * c) < (11 * k) {
return 16 + 1
return 16 + 1, nil
} else if (100 * c) < (132 * k) {
return 16 + 2
return 16 + 2, nil
} else if (3 * c) < (5 * k) {
return 16 + 3
return 16 + 3, nil
} else if (1000 * c) < (1965 * k) {
return 16 + 4
return 16 + 4, nil
} else if (1000 * c) < (2275 * k) {
return 16 + 5
return 16 + 5, nil
} else {
return 6 // steady-state table employed before its actual phase.
return 6, nil // steady-state table employed before its actual phase.
}
} else {
// Steady-state logic.
if lgK < 4 {
panic("determinePseudoPhase: lgK must be at least 4")
return 0, errors.New("determinePseudoPhase: lgK must be at least 4")
}
tmp := c >> uint(lgK-4)
phase := int(tmp & 15)
if phase < 0 || phase >= 16 {
panic(fmt.Sprintf("determinePseudoPhase: phase out of range: %d", phase))
return 0, fmt.Errorf("determinePseudoPhase: phase out of range: %d", phase)
}
return phase
return phase, nil
}
}

Expand Down
15 changes: 10 additions & 5 deletions cpc/cpc_compressed_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (
"sort"
"testing"

"github.com/stretchr/testify/assert"

"github.com/apache/datasketches-go/internal"
)

Expand Down Expand Up @@ -84,7 +86,8 @@ func TestWriteReadUnary(t *testing.T) {
if nextWordIndex != int(ptrArr[NextWordIdx]) {
t.Errorf("Before readUnary: nextWordIndex %d != ptrArr[NextWordIdx] %d", nextWordIndex, ptrArr[NextWordIdx])
}
result := readUnary(compressedWords, ptrArr)
result, err := readUnary(compressedWords, ptrArr)
assert.NoError(t, err)
t.Logf("Result: %d, expected: %d", result, i)
if result != int64(i) {
t.Errorf("Mismatch: got %d, expected %d", result, i)
Expand Down Expand Up @@ -170,9 +173,10 @@ func TestWriteReadPairs(t *testing.T) {
compressedWords := make([]int, MaxWords)
// Loop over base bits 0 to 11.
for bb := 0; bb <= 11; bb++ {
numWordsWritten := lowLevelCompressPairs(pairArray, numPairs, bb, compressedWords)
numWordsWritten, err := lowLevelCompressPairs(pairArray, numPairs, bb, compressedWords)
assert.NoError(t, err)
t.Logf("numWordsWritten = %d, bb = %d", numWordsWritten, bb)
err := lowLevelUncompressPairs(pairArray2, numPairs, bb, compressedWords, numWordsWritten)
err = lowLevelUncompressPairs(pairArray2, numPairs, bb, compressedWords, numWordsWritten)
if err != nil {
t.Errorf("Error in lowLevelUncompressPairs for bb=%d: %v", bb, err)
}
Expand Down Expand Up @@ -390,9 +394,10 @@ func TestWriteReadPairsExtended(t *testing.T) {
compressedWords := make([]int, MaxWords)
// Loop over base bits 0 to 11.
for bb := 0; bb <= 11; bb++ {
numWordsWritten := lowLevelCompressPairs(pairArray, numPairs, bb, compressedWords)
numWordsWritten, err := lowLevelCompressPairs(pairArray, numPairs, bb, compressedWords)
assert.NoError(t, err)
t.Logf("Base bits: %d, words written: %d", bb, numWordsWritten)
err := lowLevelUncompressPairs(pairArray2, numPairs, bb, compressedWords, numWordsWritten)
err = lowLevelUncompressPairs(pairArray2, numPairs, bb, compressedWords, numWordsWritten)
if err != nil {
t.Errorf("Error in lowLevelUncompressPairs for base bits %d: %v", bb, err)
}
Expand Down
Loading