diff --git a/internal/irzstd/disk.go b/internal/irzstd/disk.go index 29bac24..92d6892 100644 --- a/internal/irzstd/disk.go +++ b/internal/irzstd/disk.go @@ -41,6 +41,7 @@ type diskWriter struct { irWriter *ir.Writer irTotalBytes int zstdWriter *zstd.Encoder + state WriterState } // Opens a new [diskWriter] using files for IR and Zstd buffers. For use when use_disk_store @@ -70,6 +71,7 @@ func NewDiskWriter(irPath string, zstdPath string) (*diskWriter, error) { zstdPath: zstdPath, zstdFile: zstdFile, zstdWriter: zstdWriter, + state: Open, } return &diskWriter, nil @@ -106,6 +108,7 @@ func RecoverWriter(irPath string, zstdPath string) (*diskWriter, error) { zstdPath: zstdPath, zstdFile: zstdFile, zstdWriter: zstdWriter, + state: Open, } irFileSize, err := diskWriter.getIrFileSize() @@ -141,6 +144,10 @@ func RecoverWriter(irPath string, zstdPath string) (*diskWriter, error) { // - numEvents: Number of log events successfully written to IR writer buffer // - err: Error writing IR/Zstd, error flushing buffers func (w *diskWriter) WriteIrZstd(logEvents []ffi.LogEvent) (int, error) { + if w.state != Open { + return 0, fmt.Errorf("cannot write: writer state is %s, expected %s", w.state, Open) + } + if w.irWriter == nil { var err error w.irWriter, err = ir.NewWriter[ir.FourByteEncoding](w.irFile) @@ -178,6 +185,13 @@ func (w *diskWriter) WriteIrZstd(logEvents []ffi.LogEvent) (int, error) { // Returns: // - err: Error flushing/closing buffers func (w *diskWriter) CloseStreams() error { + if w.state == StreamsClosed { + return nil + } + if w.state != Open { + return fmt.Errorf("cannot close streams: writer state is %s, expected %s", w.state, Open) + } + // IR buffer may not be empty, so must be flushed prior to adding trailing EndOfStream byte. err := w.flushIrBuffer() if err != nil { @@ -195,14 +209,17 @@ func (w *diskWriter) CloseStreams() error { w.zstdWriter.Write([]byte{irEndOfStreamByte}) err = w.zstdWriter.Close() if err != nil { + w.state = Corrupted return err } _, err = w.zstdFile.Seek(0, io.SeekStart) if err != nil { + w.state = Corrupted return err } + w.state = StreamsClosed return nil } @@ -212,6 +229,10 @@ func (w *diskWriter) CloseStreams() error { // Returns: // - err: Error IR buffer not empty func (w *diskWriter) Reset() error { + if w.state != StreamsClosed { + return fmt.Errorf("cannot reset: writer state is %s, expected %s", w.state, StreamsClosed) + } + // Flush should be called prior to reset, so buffer should be empty. There may be a future // use case to truncate a non-empty IR buffer; however, there is currently no use case // so safer to throw an error. @@ -221,16 +242,19 @@ func (w *diskWriter) Reset() error { _, err := w.zstdFile.Seek(0, io.SeekStart) if err != nil { + w.state = Corrupted return err } err = w.zstdFile.Truncate(0) if err != nil { + w.state = Corrupted return err } w.zstdWriter.Reset(w.zstdFile) + w.state = Open return nil } @@ -263,6 +287,14 @@ func (w *diskWriter) Close() error { return nil } +// Getter for state. +// +// Returns: +// - state: Current state +func (w *diskWriter) GetState() WriterState { + return w.state +} + // Getter for Zstd Output. // // Returns: @@ -304,7 +336,6 @@ func (w *diskWriter) flushIrBuffer() error { return fmt.Errorf("error flush called with non-existent buffer") } - // Flush is called during Close(), and possible that the IR buffer is empty. if w.irTotalBytes == 0 { return nil } @@ -313,16 +344,19 @@ func (w *diskWriter) flushIrBuffer() error { _, err := w.irFile.Seek(0, io.SeekStart) if err != nil { + w.state = Corrupted return err } _, err = io.Copy(w.zstdWriter, w.irFile) if err != nil { + w.state = Corrupted return err } err = w.zstdWriter.Close() if err != nil { + w.state = Corrupted return err } @@ -332,11 +366,13 @@ func (w *diskWriter) flushIrBuffer() error { _, err = w.irFile.Seek(0, io.SeekStart) if err != nil { + w.state = Corrupted return err } err = w.irFile.Truncate(0) if err != nil { + w.state = Corrupted return err } diff --git a/internal/irzstd/memory.go b/internal/irzstd/memory.go index 1f444fd..c221e1a 100644 --- a/internal/irzstd/memory.go +++ b/internal/irzstd/memory.go @@ -17,6 +17,7 @@ type memoryWriter struct { zstdBuffer *bytes.Buffer irWriter *ir.Writer zstdWriter *zstd.Encoder + state WriterState irTotalBytes int } @@ -43,6 +44,7 @@ func NewMemoryWriter() (*memoryWriter, error) { irWriter: irWriter, zstdWriter: zstdWriter, zstdBuffer: &zstdBuffer, + state: Open, } return &memoryWriter, nil @@ -57,6 +59,10 @@ func NewMemoryWriter() (*memoryWriter, error) { // - numEvents: Number of log events successfully written to IR writer buffer // - err: Error writing IR/Zstd func (w *memoryWriter) WriteIrZstd(logEvents []ffi.LogEvent) (int, error) { + if w.state != Open { + return 0, fmt.Errorf("cannot write: writer state is %s, expected %s", w.state, Open) + } + numBytes, numEvents, err := writeIr(w.irWriter, logEvents) w.irTotalBytes += numBytes if err != nil { @@ -71,12 +77,26 @@ func (w *memoryWriter) WriteIrZstd(logEvents []ffi.LogEvent) (int, error) { // Returns: // - err: Error closing buffers func (w *memoryWriter) CloseStreams() error { + if w.state == StreamsClosed { + return nil + } + if w.state != Open { + return fmt.Errorf("cannot close streams: writer state is %s, expected %s", w.state, Open) + } + if err := w.irWriter.Close(); err != nil { + w.state = Corrupted return err } w.irWriter = nil - return w.zstdWriter.Close() + if err := w.zstdWriter.Close(); err != nil { + w.state = Corrupted + return err + } + + w.state = StreamsClosed + return nil } // Reinitialize [memoryWriter] after calling CloseStreams(). Resets individual IR and Zstd writers @@ -85,6 +105,10 @@ func (w *memoryWriter) CloseStreams() error { // Returns: // - err: Error opening IR writer func (w *memoryWriter) Reset() error { + if w.state != StreamsClosed { + return fmt.Errorf("cannot reset: writer state is %s, expected %s", w.state, StreamsClosed) + } + var err error w.zstdBuffer.Reset() w.zstdWriter.Reset(w.zstdBuffer) @@ -92,9 +116,11 @@ func (w *memoryWriter) Reset() error { w.irWriter, err = ir.NewWriter[ir.FourByteEncoding](w.zstdWriter) if err != nil { + w.state = Corrupted return err } + w.state = Open return nil } @@ -117,6 +143,14 @@ func (w *memoryWriter) GetZstdOutputSize() (int, error) { return w.zstdBuffer.Len(), nil } +// Getter for state. +// +// Returns: +// - state: Current state +func (w *memoryWriter) GetState() WriterState { + return w.state +} + // Closes [memoryWriter]. Currently used during recovery only, and advise caution using elsewhere. // Using [ir.Writer.Serializer.Close] instead of [ir.Writer.Close] so EndofStream byte is not // added. It is preferable to add postamble on recovery so that IR is in the same state diff --git a/internal/irzstd/state.go b/internal/irzstd/state.go new file mode 100644 index 0000000..448d463 --- /dev/null +++ b/internal/irzstd/state.go @@ -0,0 +1,27 @@ +package irzstd + +// WriterState is the state of a [Writer]. +type WriterState int + +const ( + // Ready to accept writes. + Open WriterState = iota + // Streams are terminated and [Writer] must be [Reset] before writing again. + StreamsClosed + // There was an unrecoverable error and writer is unusable. + Corrupted +) + +var writerStateNames = map[WriterState]string{ + Open: "Open", + StreamsClosed: "StreamsClosed", + Corrupted: "Corrupted", +} + +// Getter for string representation of [WriterState]. +// +// Returns: +// - name: String representation of the state +func (s WriterState) String() string { + return writerStateNames[s] +} diff --git a/internal/irzstd/writer.go b/internal/irzstd/writer.go index cd2f96c..d973e41 100644 --- a/internal/irzstd/writer.go +++ b/internal/irzstd/writer.go @@ -54,6 +54,12 @@ type Writer interface { // - err GetZstdOutputSize() (int, error) + // Get the current state of the Writer. + // + // Returns: + // - state: Current state (Open, StreamsClosed, or Corrupted) + GetState() WriterState + // Checks if writer is empty. True if no events are buffered. // // Returns: