diff --git a/pgtype/pgtype_default.go b/pgtype/pgtype_default.go index c5f2b3ce1..38093ef42 100644 --- a/pgtype/pgtype_default.go +++ b/pgtype/pgtype_default.go @@ -82,7 +82,7 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) defaultMap.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) - defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) + defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: &TimestampCodec{}}) defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: &TimestamptzCodec{}}) defaultMap.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 35d739566..677a2c6ea 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -46,7 +46,7 @@ func (ts *Timestamp) Scan(src any) error { switch src := src.(type) { case string: - return scanPlanTextTimestampToTimestampScanner{}.Scan([]byte(src), ts) + return (&scanPlanTextTimestampToTimestampScanner{}).Scan([]byte(src), ts) case time.Time: *ts = Timestamp{Time: src, Valid: true} return nil @@ -116,17 +116,21 @@ func (ts *Timestamp) UnmarshalJSON(b []byte) error { return nil } -type TimestampCodec struct{} +type TimestampCodec struct { + // ScanLocation is the location that the time is assumed to be in for scanning. This is different from + // TimestamptzCodec.ScanLocation in that this setting does change the instant in time that the timestamp represents. + ScanLocation *time.Location +} -func (TimestampCodec) FormatSupported(format int16) bool { +func (*TimestampCodec) FormatSupported(format int16) bool { return format == TextFormatCode || format == BinaryFormatCode } -func (TimestampCodec) PreferredFormat() int16 { +func (*TimestampCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (TimestampCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { +func (*TimestampCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(TimestampValuer); !ok { return nil } @@ -220,27 +224,27 @@ func discardTimeZone(t time.Time) time.Time { return t } -func (TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { +func (c *TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { case TimestampScanner: - return scanPlanBinaryTimestampToTimestampScanner{} + return &scanPlanBinaryTimestampToTimestampScanner{location: c.ScanLocation} } case TextFormatCode: switch target.(type) { case TimestampScanner: - return scanPlanTextTimestampToTimestampScanner{} + return &scanPlanTextTimestampToTimestampScanner{location: c.ScanLocation} } } return nil } -type scanPlanBinaryTimestampToTimestampScanner struct{} +type scanPlanBinaryTimestampToTimestampScanner struct{ location *time.Location } -func (scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error { +func (plan *scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error { scanner := (dst).(TimestampScanner) if src == nil { @@ -264,15 +268,18 @@ func (scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), ).UTC() + if plan.location != nil { + tim = time.Date(tim.Year(), tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), plan.location) + } ts = Timestamp{Time: tim, Valid: true} } return scanner.ScanTimestamp(ts) } -type scanPlanTextTimestampToTimestampScanner struct{} +type scanPlanTextTimestampToTimestampScanner struct{ location *time.Location } -func (scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error { +func (plan *scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error { scanner := (dst).(TimestampScanner) if src == nil { @@ -302,13 +309,17 @@ func (scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error { tim = time.Date(year, tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), tim.Location()) } + if plan.location != nil { + tim = time.Date(tim.Year(), tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), plan.location) + } + ts = Timestamp{Time: tim, Valid: true} } return scanner.ScanTimestamp(ts) } -func (c TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c *TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -326,7 +337,7 @@ func (c TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, return ts.Time, nil } -func (c TimestampCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { +func (c *TimestampCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go index 24f229d54..31b3ad822 100644 --- a/pgtype/timestamp_test.go +++ b/pgtype/timestamp_test.go @@ -38,6 +38,42 @@ func TestTimestampCodec(t *testing.T) { }) } +func TestTimestampCodecWithScanLocationUTC(t *testing.T) { + skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)") + + connTestRunner := defaultConnTestRunner + connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "timestamp", + OID: pgtype.TimestampOID, + Codec: &pgtype.TimestampCodec{ScanLocation: time.UTC}, + }) + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, nil, "timestamp", []pgxtest.ValueRoundTripTest{ + // Have to use pgtype.Timestamp instead of time.Time as source because otherwise the simple and exec query exec + // modes will encode the time for timestamptz. That is, they will convert it from local time zone. + {pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, new(time.Time), isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC))}, + }) +} + +func TestTimestampCodecWithScanLocationLocal(t *testing.T) { + skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)") + + connTestRunner := defaultConnTestRunner + connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "timestamp", + OID: pgtype.TimestampOID, + Codec: &pgtype.TimestampCodec{ScanLocation: time.Local}, + }) + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, nil, "timestamp", []pgxtest.ValueRoundTripTest{ + {time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local))}, + }) +} + // https://github.com/jackc/pgx/v4/pgtype/pull/128 func TestTimestampTranscodeBigTimeBinary(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {