From 8649231bb3bc00b4b9c180ce557a54ae41c28ce2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Mar 2024 13:11:13 -0500 Subject: [PATCH] Add ScanLocation to pgtype.TimestampCodec If ScanLocation is set, the timestamps will be assumed to be in the given location when scanning from the database. The Codec interface is now implemented by *pgtype.TimestampCodec instead of pgtype.TimestampCodec. This is technically a breaking change, but it is extremely unlikely that anyone is depending on this, and if there is downstream breakage it is trivial to fix. https://github.com/jackc/pgx/issues/1195 https://github.com/jackc/pgx/issues/1945 --- pgtype/pgtype_default.go | 2 +- pgtype/timestamp.go | 39 +++++++++++++++++++++++++-------------- pgtype/timestamp_test.go | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 15 deletions(-) diff --git a/pgtype/pgtype_default.go b/pgtype/pgtype_default.go index c5f2b3ce1..38093ef42 100644 --- a/pgtype/pgtype_default.go +++ b/pgtype/pgtype_default.go @@ -82,7 +82,7 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) defaultMap.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) - defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) + defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: &TimestampCodec{}}) defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: &TimestamptzCodec{}}) defaultMap.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 35d739566..677a2c6ea 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -46,7 +46,7 @@ func (ts *Timestamp) Scan(src any) error { switch src := src.(type) { case string: - return scanPlanTextTimestampToTimestampScanner{}.Scan([]byte(src), ts) + return (&scanPlanTextTimestampToTimestampScanner{}).Scan([]byte(src), ts) case time.Time: *ts = Timestamp{Time: src, Valid: true} return nil @@ -116,17 +116,21 @@ func (ts *Timestamp) UnmarshalJSON(b []byte) error { return nil } -type TimestampCodec struct{} +type TimestampCodec struct { + // ScanLocation is the location that the time is assumed to be in for scanning. This is different from + // TimestamptzCodec.ScanLocation in that this setting does change the instant in time that the timestamp represents. + ScanLocation *time.Location +} -func (TimestampCodec) FormatSupported(format int16) bool { +func (*TimestampCodec) FormatSupported(format int16) bool { return format == TextFormatCode || format == BinaryFormatCode } -func (TimestampCodec) PreferredFormat() int16 { +func (*TimestampCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (TimestampCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { +func (*TimestampCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(TimestampValuer); !ok { return nil } @@ -220,27 +224,27 @@ func discardTimeZone(t time.Time) time.Time { return t } -func (TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { +func (c *TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { case TimestampScanner: - return scanPlanBinaryTimestampToTimestampScanner{} + return &scanPlanBinaryTimestampToTimestampScanner{location: c.ScanLocation} } case TextFormatCode: switch target.(type) { case TimestampScanner: - return scanPlanTextTimestampToTimestampScanner{} + return &scanPlanTextTimestampToTimestampScanner{location: c.ScanLocation} } } return nil } -type scanPlanBinaryTimestampToTimestampScanner struct{} +type scanPlanBinaryTimestampToTimestampScanner struct{ location *time.Location } -func (scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error { +func (plan *scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error { scanner := (dst).(TimestampScanner) if src == nil { @@ -264,15 +268,18 @@ func (scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), ).UTC() + if plan.location != nil { + tim = time.Date(tim.Year(), tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), plan.location) + } ts = Timestamp{Time: tim, Valid: true} } return scanner.ScanTimestamp(ts) } -type scanPlanTextTimestampToTimestampScanner struct{} +type scanPlanTextTimestampToTimestampScanner struct{ location *time.Location } -func (scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error { +func (plan *scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error { scanner := (dst).(TimestampScanner) if src == nil { @@ -302,13 +309,17 @@ func (scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error { tim = time.Date(year, tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), tim.Location()) } + if plan.location != nil { + tim = time.Date(tim.Year(), tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), plan.location) + } + ts = Timestamp{Time: tim, Valid: true} } return scanner.ScanTimestamp(ts) } -func (c TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c *TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -326,7 +337,7 @@ func (c TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, return ts.Time, nil } -func (c TimestampCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { +func (c *TimestampCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go index 24f229d54..31b3ad822 100644 --- a/pgtype/timestamp_test.go +++ b/pgtype/timestamp_test.go @@ -38,6 +38,42 @@ func TestTimestampCodec(t *testing.T) { }) } +func TestTimestampCodecWithScanLocationUTC(t *testing.T) { + skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)") + + connTestRunner := defaultConnTestRunner + connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "timestamp", + OID: pgtype.TimestampOID, + Codec: &pgtype.TimestampCodec{ScanLocation: time.UTC}, + }) + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, nil, "timestamp", []pgxtest.ValueRoundTripTest{ + // Have to use pgtype.Timestamp instead of time.Time as source because otherwise the simple and exec query exec + // modes will encode the time for timestamptz. That is, they will convert it from local time zone. + {pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, new(time.Time), isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC))}, + }) +} + +func TestTimestampCodecWithScanLocationLocal(t *testing.T) { + skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)") + + connTestRunner := defaultConnTestRunner + connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "timestamp", + OID: pgtype.TimestampOID, + Codec: &pgtype.TimestampCodec{ScanLocation: time.Local}, + }) + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, nil, "timestamp", []pgxtest.ValueRoundTripTest{ + {time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local))}, + }) +} + // https://github.com/jackc/pgx/v4/pgtype/pull/128 func TestTimestampTranscodeBigTimeBinary(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {