diff --git a/provider.go b/provider.go index 3a1fd3e..316bb6b 100644 --- a/provider.go +++ b/provider.go @@ -56,20 +56,32 @@ func (provider *Provider) selectColumns() string { }, ",") } +// GetEvents retrieves events from the provider +// eventIndexEnd 0 means max +// see https://github.com/asynkron/protoactor-go/blob/dev/persistence/plugin.go#L65 func (provider *Provider) GetEvents(actorName string, eventIndexStart int, eventIndexEnd int, callback func(e interface{})) { tx, _ := provider.db.BeginTx(provider.context, pgx.TxOptions{IsoLevel: pgx.Serializable}) defer tx.Commit(provider.context) - rows, err := tx.Query( - provider.context, - fmt.Sprintf( - "SELECT %s FROM %s WHERE %s = $1 AND %s BETWEEN $2 AND $3 ORDER BY %s ASC", + query := fmt.Sprintf( + "SELECT %s FROM %s WHERE %s = $1 AND %s BETWEEN $2 AND $3 ORDER BY %s ASC", + provider.selectColumns(), + provider.tableSchema.JournalTableName(), + provider.tableSchema.ActorName(), + provider.tableSchema.SequenceNumber(), + provider.tableSchema.SequenceNumber(), + ) + args := []interface{}{actorName, eventIndexStart, eventIndexEnd} + if eventIndexEnd == 0 { + query = fmt.Sprintf( + "SELECT %s FROM %s WHERE %s = $1 AND %s >= $2 ORDER BY %s ASC", provider.selectColumns(), provider.tableSchema.JournalTableName(), provider.tableSchema.ActorName(), provider.tableSchema.SequenceNumber(), - provider.tableSchema.SequenceNumber(), - ), - actorName, eventIndexStart, eventIndexEnd) + provider.tableSchema.SequenceNumber()) + args = []interface{}{actorName, eventIndexStart} + } + rows, err := tx.Query(provider.context, query, args...) if !errors.Is(err, sql.ErrNoRows) && err != nil { provider.logger.Error(err.Error(), slog.String("actor_name", actorName)) return diff --git a/provider_test.go b/provider_test.go index 6591b53..8f81d44 100644 --- a/provider_test.go +++ b/provider_test.go @@ -46,6 +46,18 @@ func TestProvider_PersistEvent(t *testing.T) { if !reflect.DeepEqual(evt, evv) { t.Errorf("unexpected event %v", evv) } + + var evv2 *testdata.UserCreated + provider.GetEvents("user", 1, 0, func(e interface{}) { + ev, ok := e.(*testdata.UserCreated) + if !ok { + t.Error("unexpected type") + } + evv2 = ev + }) + if !reflect.DeepEqual(evt, evv2) { + t.Errorf("unexpected event %v", evv2) + } } func TestProvider_PersistSnapshot(t *testing.T) {