diff --git a/bench2/Benchmark/Nockma/Encoding/ByteString.hs b/bench2/Benchmark/Nockma/Encoding/ByteString.hs index f471e47b30..722ba88510 100644 --- a/bench2/Benchmark/Nockma/Encoding/ByteString.hs +++ b/bench2/Benchmark/Nockma/Encoding/ByteString.hs @@ -10,10 +10,17 @@ randomBytes numBytes = do sg <- getStdGen return (fst (genByteString numBytes sg)) +testBytesSize :: Int +testBytesSize = 250000 + bm :: Benchmark bm = bgroup "ByteString Encoding to/from integer" - [ env (randomBytes 250000) (\bs -> bench "encode bytes to integer" (nf Encoding.encodeByteString bs)), - env (Encoding.encodeByteString <$> randomBytes 250000) (\i -> bench "decode bytes from integer" (nf Encoding.decodeByteString i)) + [ env + (randomBytes testBytesSize) + (\bs -> bench "encode bytes to integer" (nf Encoding.encodeByteString bs)), + env + (Encoding.encodeByteString <$> randomBytes testBytesSize) + (\i -> bench "decode bytes from integer" (nf (Encoding.decodeByteStringWithDefault (error "failed to decode")) i)) ] diff --git a/src/Juvix/Compiler/Core/Evaluator.hs b/src/Juvix/Compiler/Core/Evaluator.hs index 0c9bff3f60..21bc538477 100644 --- a/src/Juvix/Compiler/Core/Evaluator.hs +++ b/src/Juvix/Compiler/Core/Evaluator.hs @@ -497,13 +497,20 @@ geval opts herr tab env0 = eval' env0 -- Deserialize a Integer, serialized using `serializeInteger` to a Node deserializeFromInteger :: Integer -> Node - deserializeFromInteger = deserializeNode . Encoding.decodeByteString + deserializeFromInteger = deserializeNode . decodeByteString {-# INLINE deserializeFromInteger #-} serializeToInteger :: Node -> Integer serializeToInteger = Encoding.encodeByteString . serializeNode {-# INLINE serializeToInteger #-} + decodeByteString :: Integer -> ByteString + decodeByteString = Encoding.decodeByteStringWithDefault decodeErr + where + decodeErr :: ByteString + decodeErr = err "failed to decode Integer" + {-# INLINE decodeByteString #-} + sign :: Node -> ByteString -> Node sign !messageNode !secretKeyBs = let !message = serializeNode messageNode diff --git a/src/Juvix/Compiler/Nockma/Encoding/Base.hs b/src/Juvix/Compiler/Nockma/Encoding/Base.hs index e76c7ac17f..eed92d717f 100644 --- a/src/Juvix/Compiler/Nockma/Encoding/Base.hs +++ b/src/Juvix/Compiler/Nockma/Encoding/Base.hs @@ -2,7 +2,6 @@ module Juvix.Compiler.Nockma.Encoding.Base where import Data.Bit as Bit import Data.Bits -import Data.Vector.Unboxed qualified as U import Juvix.Compiler.Nockma.Encoding.Effect.BitReader import Juvix.Compiler.Nockma.Encoding.Effect.BitWriter import Juvix.Prelude.Base @@ -12,33 +11,28 @@ import Juvix.Prelude.Base writeIntegral :: forall a r. (Integral a, Member BitWriter r) => a -> Sem r () writeIntegral x | x < 0 = error "integerToVectorBits: negative integers are not supported in this implementation" - | otherwise = unfoldBits (fromIntegral x) + | otherwise = unfoldBits 0 (fromIntegral x) where - unfoldBits :: Integer -> Sem r () - unfoldBits n - | n == 0 = return () - | otherwise = writeBit (Bit (testBit n 0)) <> unfoldBits (n `shiftR` 1) + len = bitLength x -integerToVectorBits :: (Integral a) => a -> Bit.Vector Bit -integerToVectorBits = run . execBitWriter . writeIntegral + unfoldBits :: Int -> Integer -> Sem r () + unfoldBits idx n + | idx == len = return () + | otherwise = writeBit (Bit (testBit n idx)) <> unfoldBits (idx + 1) n -- | Computes the number of bits required to store the argument in binary -- NB: 0 is encoded to the empty bit vector (as specified by the Hoon serialization spec), so 0 has bit length 0. -bitLength :: forall a. (Integral a) => a -> Int -bitLength = - length - . takeWhile (/= 0) - . iterate (`shiftR` 1) - . toInteger +bitLength :: (Integral a) => a -> Int +bitLength n + | n == 0 = 0 + | otherwise = fromIntegral (integerLog2 (abs (fromIntegral n))) + 1 --- | Decode a vector of bits (ordered from least to most significant bits) to an integer -vectorBitsToInteger :: Bit.Vector Bit -> Integer -vectorBitsToInteger = U.ifoldl' go 0 - where - go :: Integer -> Int -> Bit -> Integer - go acc idx (Bit b) - | b = setBit acc idx - | otherwise = acc +integerToVectorBits :: (Integral a) => a -> Bit.Vector Bit +integerToVectorBits = run . execBitWriter . writeIntegral + +-- | Decode a vector of bits (ordered from least to most significant bits) to a ByteString +vectorBitsToByteString :: Bit.Vector Bit -> ByteString +vectorBitsToByteString = cloneToByteString -- | Transform a Natural to an Int, computes Nothing if the Natural does not fit in an Int safeNaturalToInt :: Natural -> Maybe Int diff --git a/src/Juvix/Compiler/Nockma/Encoding/ByteString.hs b/src/Juvix/Compiler/Nockma/Encoding/ByteString.hs index f7eaf048a1..de28cd4817 100644 --- a/src/Juvix/Compiler/Nockma/Encoding/ByteString.hs +++ b/src/Juvix/Compiler/Nockma/Encoding/ByteString.hs @@ -1,9 +1,13 @@ module Juvix.Compiler.Nockma.Encoding.ByteString where +import Data.Bit (Bit) +import Data.Bit qualified as Bit import Data.Bits import Data.ByteString qualified as BS -import Data.ByteString.Builder qualified as BB import Data.ByteString.Builder qualified as BS +import Juvix.Compiler.Nockma.Encoding.Base +import Juvix.Compiler.Nockma.Encoding.Effect.BitReader +import Juvix.Compiler.Nockma.Encoding.Effect.BitWriter import Juvix.Compiler.Nockma.Language import Juvix.Prelude.Base @@ -27,6 +31,25 @@ naturalToByteString = integerToByteStringLE . toInteger byteStringToIntegerLE :: ByteString -> Integer byteStringToIntegerLE = BS.foldr (\b acc -> acc `shiftL` 8 .|. fromIntegral b) 0 +byteStringToIntegerLEChunked :: ByteString -> Integer +byteStringToIntegerLEChunked = foldr' go 0 . map (first byteStringChunkToInteger) . chunkByteString + where + chunkSize :: Int + chunkSize = 1024 + + go :: (Integer, Int) -> Integer -> Integer + go (i, size) acc = acc `shiftL` (8 * size) .|. i + + chunkByteString :: ByteString -> [(ByteString, Int)] + chunkByteString bs + | BS.null bs = [] + | otherwise = + let (chunk, rest) = BS.splitAt chunkSize bs + in (chunk, BS.length chunk) : chunkByteString rest + + byteStringChunkToInteger :: ByteString -> Integer + byteStringChunkToInteger = BS.foldr' (\b acc -> acc `shiftL` 8 .|. fromIntegral b) 0 + integerToByteStringLE :: Integer -> ByteString integerToByteStringLE = BS.toStrict . BS.toLazyByteString . go where @@ -64,114 +87,29 @@ padByteString n bs | BS.length bs >= n = bs | otherwise = BS.append bs (BS.replicate (n - BS.length bs) 0) --- | Encode an Int with a variable-length encoding --- --- The input Int is encoded in 7 bit chunks in LSB ordering. The most significant --- bit of each chunk is used to indicate when there are more bytes to read, --- 1 meaning more bytes, 0 meaning no more bytes. --- --- For example, the binary representation of 263202 is divided into 3 7-bit chunks: --- --- 263202 = 10000 0001000 0100010 --- chunk1 chunk2 chunk3 --- --- The chunks are then combined using 3 bytes in LSB ordering, with a 1 in the MSB of the first --- two bytes (indicating that another byte follows). The final byte has a 0 in the MSB bit. --- --- chunk3 chunk2 chunk1 --- 1_0100010 1_0001000 0_0010000 -encodeVarInt :: Int -> ByteString -encodeVarInt = \case - 0 -> BS.singleton 0 - n -> BS.toStrict (BB.toLazyByteString (buildVarInt n)) - where - buildVarInt :: Int -> BB.Builder - buildVarInt = \case - 0 -> mempty - i -> - let byteChunk = fromIntegral (i .&. 0x7F) -- Extract a 7-bit chunk - next = i `shiftR` 7 -- Shift to the next 7-bit chunk - currentByte = - if - | next == 0 -> byteChunk -- No more bytes, so most significant bit for this chunk is 0 - | otherwise -> byteChunk .|. 0x80 -- More bytes, so most significant bit for this chunk is 1 - in BB.word8 currentByte <> buildVarInt next - -byteStringToIntegerBE :: ByteString -> Integer -byteStringToIntegerBE = foldl' go 0 . map (first byteStringChunkToInteger) . chunkByteString - where - chunkSize :: Int - chunkSize = 1024 - - go :: Integer -> (Integer, Int) -> Integer - go acc (i, size) = acc `shiftL` (8 * size) .|. i - - chunkByteString :: ByteString -> [(ByteString, Int)] - chunkByteString bs - | BS.null bs = [] - | otherwise = - let (chunk, rest) = BS.splitAt chunkSize bs - in (chunk, BS.length chunk) : chunkByteString rest - - byteStringChunkToInteger :: ByteString -> Integer - byteStringChunkToInteger = BS.foldl' (\acc b -> acc `shiftL` 8 .|. fromIntegral b) 0 +vectorBitsToInteger :: Bit.Vector Bit -> Integer +vectorBitsToInteger = byteStringToIntegerLEChunked . vectorBitsToByteString --- | encode a ByteString to an Integer (in MSB ordering) with its length as part of the encoding. +-- | encode a ByteString to an Integer with its length as part of the encoding. encodeByteString :: ByteString -> Integer -encodeByteString bs = byteStringToIntegerBE (encodedLength <> bs) +encodeByteString = vectorBitsToInteger . run . execBitWriter . go where - encodedLength :: ByteString - encodedLength = encodeVarInt (BS.length bs) + go :: ByteString -> Sem (BitWriter ': r) () + go bs = do + let len = BS.length bs + writeLength len + writeByteString bs -- | decode a ByteString that was encoded using `encodeByteString` -decodeByteString :: Integer -> ByteString -decodeByteString n = padByteString len bytes - where - (len, bytes) = decodeVarInt (integerToBytes n) - --- | Decode an integer in MSB ordering to a bytestring. -integerToBytes :: Integer -> ByteString -integerToBytes 0 = BS.singleton 0 -integerToBytes n = BS.reverse $ BS.unfoldr go n - where - go :: Integer -> Maybe (Word8, Integer) - go = \case - 0 -> Nothing - i -> Just (fromIntegral (i .&. 0xff), i `shiftR` 8) - --- | Decode a variable-length encoded Int (using `encodeVarInt`) from the start of a ByteString. --- --- An Int is accumulated from the least significant 7-bits chunk of each byte. The --- most significant bit of each byte indicates if more bytes of the input should --- be read. If the most significant bit is one, then there are more bytes, if it --- is 0 then there are no more bytes. --- --- For example: --- --- byte1 byte2 byte3 remainder --- 1_0100010 1_0001000 0_0010000 ... --- --- The first byte has most significant bit = 1 so we accumulate the least significant 7 bits and continue. --- --- acc: 100010 --- --- The second byte has most significant bit = 1 so we accumulate and continue. The bytes are --- encoded using LSB ordering so we must shift this chunk left by 7: --- --- acc: 0001000 0100010 --- --- The third byte has most significant bit = 0 so this is the last byte. We must shift this chunk by 2 * 7 = 14 --- --- result : 10000 0001000 0100010 = 263202 -decodeVarInt :: ByteString -> (Int, ByteString) -decodeVarInt bs = go 0 0 bs +decodeByteString :: forall r. (Member (Error BitReadError) r) => Integer -> Sem r ByteString +decodeByteString i = evalBitReader (integerToVectorBits i) go where - go :: Int -> Int -> ByteString -> (Int, ByteString) - go acc toShift s = case BS.uncons s of - Nothing -> (acc, BS.empty) - Just (x, xs) -> - if - | x .&. 0x80 == 0 -> (acc .|. (fromIntegral x `shiftL` toShift), xs) -- The most significant bit is 0, no more bytes - | otherwise -> - let chunk = x .&. 0x7F -- Extract the next 7-bit chunk - in go (acc .|. (fromIntegral chunk `shiftL` toShift)) (toShift + 7) xs + go :: Sem (BitReader ': r) ByteString + go = do + len <- consumeLength + v <- consumeRemaining + return (padByteString len (Bit.cloneToByteString v)) + +-- | decode a ByteString that was encoded using `encodeByteString` with a default that's used if decoding fails. +decodeByteStringWithDefault :: ByteString -> Integer -> ByteString +decodeByteStringWithDefault d = fromRight d . run . runErrorNoCallStack @BitReadError . decodeByteString diff --git a/src/Juvix/Compiler/Nockma/Encoding/Cue.hs b/src/Juvix/Compiler/Nockma/Encoding/Cue.hs index ab9aaa69fa..ee20276373 100644 --- a/src/Juvix/Compiler/Nockma/Encoding/Cue.hs +++ b/src/Juvix/Compiler/Nockma/Encoding/Cue.hs @@ -2,6 +2,7 @@ module Juvix.Compiler.Nockma.Encoding.Cue where import Data.Bit as Bit import Juvix.Compiler.Nockma.Encoding.Base +import Juvix.Compiler.Nockma.Encoding.ByteString import Juvix.Compiler.Nockma.Encoding.Effect.BitReader import Juvix.Compiler.Nockma.Language import Juvix.Compiler.Nockma.Pretty.Base diff --git a/src/Juvix/Compiler/Nockma/Encoding/Jam.hs b/src/Juvix/Compiler/Nockma/Encoding/Jam.hs index f3ba8de81a..393b174e79 100644 --- a/src/Juvix/Compiler/Nockma/Encoding/Jam.hs +++ b/src/Juvix/Compiler/Nockma/Encoding/Jam.hs @@ -9,6 +9,7 @@ module Juvix.Compiler.Nockma.Encoding.Jam where import Data.Bit as Bit import Data.Bits import Juvix.Compiler.Nockma.Encoding.Base +import Juvix.Compiler.Nockma.Encoding.ByteString import Juvix.Compiler.Nockma.Encoding.Effect.BitWriter import Juvix.Compiler.Nockma.Language import Juvix.Prelude.Base diff --git a/test/Nockma/Encoding.hs b/test/Nockma/Encoding.hs index 07e0c20b6e..1e0e63d0dd 100644 --- a/test/Nockma/Encoding.hs +++ b/test/Nockma/Encoding.hs @@ -9,9 +9,9 @@ import Test.Tasty.Hedgehog propEncodingRoundtrip :: Property propEncodingRoundtrip = property $ do - -- The range must be greater than the chunkSize in `byteStringToIntegerBE` + -- The range must be greater than the chunkSize in `byteStringToIntegerLEChunked` bs <- forAll (Gen.bytes (Range.linear 0 3000)) - Encoding.decodeByteString (Encoding.encodeByteString bs) === bs + Encoding.decodeByteStringWithDefault (error "failed to decode") (Encoding.encodeByteString bs) === bs allTests :: TestTree allTests = testGroup "Nockma encoding" [testProperty "Roundtrip ByteArray to/from integer encoding" propEncodingRoundtrip]