diff --git a/pgtype/json.go b/pgtype/json.go index 99628092a..e71dcb9bf 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -8,17 +8,20 @@ import ( "reflect" ) -type JSONCodec struct{} +type JSONCodec struct { + Marshal func(v any) ([]byte, error) + Unmarshal func(data []byte, v any) error +} -func (JSONCodec) FormatSupported(format int16) bool { +func (*JSONCodec) FormatSupported(format int16) bool { return format == TextFormatCode || format == BinaryFormatCode } -func (JSONCodec) PreferredFormat() int16 { +func (*JSONCodec) PreferredFormat() int16 { return TextFormatCode } -func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { +func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch value.(type) { case string: return encodePlanJSONCodecEitherFormatString{} @@ -44,7 +47,9 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod // // https://github.com/jackc/pgx/issues/1681 case json.Marshaler: - return encodePlanJSONCodecEitherFormatMarshal{} + return &encodePlanJSONCodecEitherFormatMarshal{ + marshal: c.Marshal, + } } // Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the @@ -61,7 +66,9 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod } } - return encodePlanJSONCodecEitherFormatMarshal{} + return &encodePlanJSONCodecEitherFormatMarshal{ + marshal: c.Marshal, + } } type encodePlanJSONCodecEitherFormatString struct{} @@ -96,10 +103,12 @@ func (encodePlanJSONCodecEitherFormatJSONRawMessage) Encode(value any, buf []byt return buf, nil } -type encodePlanJSONCodecEitherFormatMarshal struct{} +type encodePlanJSONCodecEitherFormatMarshal struct { + marshal func(v any) ([]byte, error) +} -func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) { - jsonBytes, err := json.Marshal(value) +func (e *encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) { + jsonBytes, err := e.marshal(value) if err != nil { return nil, err } @@ -108,7 +117,7 @@ func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (new return buf, nil } -func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { +func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch target.(type) { case *string: return scanPlanAnyToString{} @@ -141,7 +150,9 @@ func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan return &scanPlanSQLScanner{formatCode: format} } - return scanPlanJSONToJSONUnmarshal{} + return &scanPlanJSONToJSONUnmarshal{ + unmarshal: c.Unmarshal, + } } type scanPlanAnyToString struct{} @@ -173,9 +184,11 @@ func (scanPlanJSONToBytesScanner) Scan(src []byte, dst any) error { return scanner.ScanBytes(src) } -type scanPlanJSONToJSONUnmarshal struct{} +type scanPlanJSONToJSONUnmarshal struct { + unmarshal func(data []byte, v any) error +} -func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { +func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { if src == nil { dstValue := reflect.ValueOf(dst) if dstValue.Kind() == reflect.Ptr { @@ -193,10 +206,10 @@ func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { elem := reflect.ValueOf(dst).Elem() elem.Set(reflect.Zero(elem.Type())) - return json.Unmarshal(src, dst) + return s.unmarshal(src, dst) } -func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c *JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -206,12 +219,12 @@ func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return dstBuf, nil } -func (c JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { +func (c *JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } var dst any - err := json.Unmarshal(src, &dst) + err := c.Unmarshal(src, &dst) return dst, err } diff --git a/pgtype/json_test.go b/pgtype/json_test.go index 2474fd371..5111522f9 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -6,9 +6,11 @@ import ( "database/sql/driver" "encoding/json" "errors" + "reflect" "testing" pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/require" ) @@ -224,3 +226,26 @@ func TestJSONCodecEncodeJSONMarshalerThatCanBeWrapped(t *testing.T) { require.Equal(t, `{"custom":"thing"}`, jsonStr) }) } + +func TestJSONCodecCustomMarshal(t *testing.T) { + connTestRunner := defaultConnTestRunner + connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "json", OID: pgtype.JSONOID, Codec: &pgtype.JSONCodec{ + Marshal: func(v any) ([]byte, error) { + return []byte(`{"custom":"value"}`), nil + }, + Unmarshal: func(data []byte, v any) error { + return json.Unmarshal([]byte(`{"custom":"value"}`), v) + }, + }}) + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{ + // There is no space between "custom" and "value" in json type. + {map[string]any{"something": "else"}, new(string), isExpectedEq(`{"custom":"value"}`)}, + {[]byte(`{"something":"else"}`), new(map[string]any), func(v any) bool { + return reflect.DeepEqual(v, map[string]any{"custom": "value"}) + }}, + }) +} diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go index 25555e7ff..4d4eb58e5 100644 --- a/pgtype/jsonb.go +++ b/pgtype/jsonb.go @@ -2,29 +2,31 @@ package pgtype import ( "database/sql/driver" - "encoding/json" "fmt" ) -type JSONBCodec struct{} +type JSONBCodec struct { + Marshal func(v any) ([]byte, error) + Unmarshal func(data []byte, v any) error +} -func (JSONBCodec) FormatSupported(format int16) bool { +func (*JSONBCodec) FormatSupported(format int16) bool { return format == TextFormatCode || format == BinaryFormatCode } -func (JSONBCodec) PreferredFormat() int16 { +func (*JSONBCodec) PreferredFormat() int16 { return TextFormatCode } -func (JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { +func (c *JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch format { case BinaryFormatCode: - plan := JSONCodec{}.PlanEncode(m, oid, TextFormatCode, value) + plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, TextFormatCode, value) if plan != nil { return &encodePlanJSONBCodecBinaryWrapper{textPlan: plan} } case TextFormatCode: - return JSONCodec{}.PlanEncode(m, oid, format, value) + return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, format, value) } return nil @@ -39,15 +41,15 @@ func (plan *encodePlanJSONBCodecBinaryWrapper) Encode(value any, buf []byte) (ne return plan.textPlan.Encode(value, buf) } -func (JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { +func (c *JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: - plan := JSONCodec{}.PlanScan(m, oid, TextFormatCode, target) + plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, TextFormatCode, target) if plan != nil { return &scanPlanJSONBCodecBinaryUnwrapper{textPlan: plan} } case TextFormatCode: - return JSONCodec{}.PlanScan(m, oid, format, target) + return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, format, target) } return nil @@ -73,7 +75,7 @@ func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(src []byte, dst any) error { return plan.textPlan.Scan(src[1:], dst) } -func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c *JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -100,7 +102,7 @@ func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src } } -func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { +func (c *JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } @@ -122,6 +124,6 @@ func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (a } var dst any - err := json.Unmarshal(src, &dst) + err := c.Unmarshal(src, &dst) return dst, err } diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go index 5bfbdbe3a..f34b3ad51 100644 --- a/pgtype/jsonb_test.go +++ b/pgtype/jsonb_test.go @@ -2,9 +2,12 @@ package pgtype_test import ( "context" + "encoding/json" + "reflect" "testing" pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/require" ) @@ -80,3 +83,26 @@ func TestJSONBCodecEncodeJSONMarshalerThatCanBeWrapped(t *testing.T) { require.Equal(t, `{"custom": "thing"}`, jsonStr) // Note that unlike json, jsonb reformats the JSON string. }) } + +func TestJSONBCodecCustomMarshal(t *testing.T) { + connTestRunner := defaultConnTestRunner + connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "jsonb", OID: pgtype.JSONBOID, Codec: &pgtype.JSONBCodec{ + Marshal: func(v any) ([]byte, error) { + return []byte(`{"custom":"value"}`), nil + }, + Unmarshal: func(data []byte, v any) error { + return json.Unmarshal([]byte(`{"custom":"value"}`), v) + }, + }}) + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "jsonb", []pgxtest.ValueRoundTripTest{ + // There is space between "custom" and "value" in jsonb type. + {map[string]any{"something": "else"}, new(string), isExpectedEq(`{"custom": "value"}`)}, + {[]byte(`{"something":"else"}`), new(map[string]any), func(v any) bool { + return reflect.DeepEqual(v, map[string]any{"custom": "value"}) + }}, + }) +} diff --git a/pgtype/pgtype_default.go b/pgtype/pgtype_default.go index 38093ef42..9525f37c9 100644 --- a/pgtype/pgtype_default.go +++ b/pgtype/pgtype_default.go @@ -65,8 +65,8 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) defaultMap.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) defaultMap.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) - defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) - defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) + defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: &JSONCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}}) + defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: &JSONBCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}}) defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}}) defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}})