diff --git a/api/contract.go b/api/contract.go index 5606b1b35..4e8f74099 100644 --- a/api/contract.go +++ b/api/contract.go @@ -2,6 +2,7 @@ package api import ( "errors" + "time" rhpv2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" @@ -50,8 +51,9 @@ type ( // ContractMetadata contains all metadata for a contract. ContractMetadata struct { - ID types.FileContractID `json:"id"` - HostKey types.PublicKey `json:"hostKey"` + CreatedAt time.Time `json:"createdAt"` + ID types.FileContractID `json:"id"` + HostKey types.PublicKey `json:"hostKey"` ProofHeight uint64 `json:"proofHeight"` RenewedFrom types.FileContractID `json:"renewedFrom"` @@ -119,9 +121,9 @@ type ( // ContractAddRequest is the request type for the /contract/:id endpoint. ContractAddRequest struct { - Revision rhpv2.ContractRevision `json:"revision"` ContractPrice types.Currency `json:"contractPrice"` InitialRenterFunds types.Currency `json:"initialRenterFunds"` + Revision rhpv2.ContractRevision `json:"revision"` StartHeight uint64 `json:"startHeight"` State string `json:"state,omitempty"` } @@ -174,17 +176,6 @@ type ( RenterFunds types.Currency `json:"renterFunds"` } - // ContractRenewedRequest is the request type for the /contract/:id/renewed - // endpoint. - ContractRenewedRequest struct { - Contract rhpv2.ContractRevision `json:"contract"` - ContractPrice types.Currency `json:"contractPrice"` - InitialRenterFunds types.Currency `json:"initialRenterFunds"` - RenewedFrom types.FileContractID `json:"renewedFrom"` - StartHeight uint64 `json:"startHeight"` - State string `json:"state,omitempty"` - } - // ContractRootsResponse is the response type for the /contract/:id/roots // endpoint. ContractRootsResponse struct { diff --git a/bus/bus.go b/bus/bus.go index 420862d9e..6c4b3dc0c 100644 --- a/bus/bus.go +++ b/bus/bus.go @@ -210,7 +210,7 @@ type ( // A MetadataStore stores information about contracts and objects. MetadataStore interface { - AddContract(ctx context.Context, c api.ContractMetadata) error + AddRenewal(ctx context.Context, c api.ContractMetadata) error AncestorContracts(ctx context.Context, fcid types.FileContractID, minStartHeight uint64) ([]api.ContractMetadata, error) ArchiveContract(ctx context.Context, id types.FileContractID, reason string) error ArchiveContracts(ctx context.Context, toArchive map[types.FileContractID]string) error @@ -218,9 +218,10 @@ type ( Contract(ctx context.Context, id types.FileContractID) (api.ContractMetadata, error) Contracts(ctx context.Context, opts api.ContractsOpts) ([]api.ContractMetadata, error) ContractSets(ctx context.Context) ([]string, error) + InsertContract(ctx context.Context, c api.ContractMetadata) error RecordContractSpending(ctx context.Context, records []api.ContractSpendingRecord) error RemoveContractSet(ctx context.Context, name string) error - RenewContract(ctx context.Context, c api.ContractMetadata) error + PutContract(ctx context.Context, c api.ContractMetadata) error RenewedContract(ctx context.Context, renewedFrom types.FileContractID) (api.ContractMetadata, error) UpdateContractSet(ctx context.Context, set string, toAdd, toRemove []types.FileContractID) error @@ -410,10 +411,11 @@ func (b *Bus) Handler() http.Handler { "GET /consensus/siafundfee/:payout": b.contractTaxHandlerGET, "GET /consensus/state": b.consensusStateHandler, - "POST /contracts": b.contractsFormHandler, + "PUT /contracts": b.contractsHandlerPUT, "GET /contracts": b.contractsHandlerGET, "DELETE /contracts/all": b.contractsAllHandlerDELETE, "POST /contracts/archive": b.contractsArchiveHandlerPOST, + "POST /contracts/form": b.contractsFormHandler, "GET /contracts/prunable": b.contractsPrunableDataHandlerGET, "GET /contracts/renewed/:id": b.contractsRenewedIDHandlerGET, "GET /contracts/sets": b.contractsSetsHandlerGET, @@ -421,7 +423,6 @@ func (b *Bus) Handler() http.Handler { "DELETE /contracts/set/:set": b.contractsSetHandlerDELETE, "POST /contracts/spending": b.contractsSpendingHandlerPOST, "GET /contract/:id": b.contractIDHandlerGET, - "POST /contract/:id": b.contractIDHandlerPOST, "DELETE /contract/:id": b.contractIDHandlerDELETE, "POST /contract/:id/acquire": b.contractAcquireHandlerPOST, "GET /contract/:id/ancestors": b.contractIDAncestorsHandler, @@ -532,7 +533,7 @@ func (b *Bus) Shutdown(ctx context.Context) error { } func (b *Bus) addContract(ctx context.Context, rev rhpv2.ContractRevision, contractPrice, initialRenterFunds types.Currency, startHeight uint64, state string) (api.ContractMetadata, error) { - if err := b.ms.AddContract(ctx, api.ContractMetadata{ + if err := b.ms.InsertContract(ctx, api.ContractMetadata{ ID: rev.ID(), HostKey: rev.HostKey(), StartHeight: startHeight, @@ -563,7 +564,7 @@ func (b *Bus) addContract(ctx context.Context, rev rhpv2.ContractRevision, contr } func (b *Bus) addRenewal(ctx context.Context, renewedFrom types.FileContractID, rev rhpv2.ContractRevision, contractPrice, initialRenterFunds types.Currency, startHeight uint64, state string) (api.ContractMetadata, error) { - if err := b.ms.RenewContract(ctx, api.ContractMetadata{ + if err := b.ms.AddRenewal(ctx, api.ContractMetadata{ ID: rev.ID(), HostKey: rev.HostKey(), RenewedFrom: renewedFrom, diff --git a/bus/client/contracts.go b/bus/client/contracts.go index 5789eb035..e29eaaf71 100644 --- a/bus/client/contracts.go +++ b/bus/client/contracts.go @@ -6,21 +6,14 @@ import ( "net/url" "time" - rhpv2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" "go.sia.tech/renterd/api" ) -// AddContract adds the provided contract to the metadata store. -func (c *Client) AddContract(ctx context.Context, revision rhpv2.ContractRevision, contractPrice, initialRenterFunds types.Currency, startHeight uint64, state string) (added api.ContractMetadata, err error) { - err = c.c.WithContext(ctx).POST(fmt.Sprintf("/contract/%s", revision.ID()), api.ContractAddRequest{ - Revision: revision, - ContractPrice: contractPrice, - InitialRenterFunds: initialRenterFunds, - StartHeight: startHeight, - State: state, - }, &added) - return +// AddContract adds the provided contract to the metadata store, if the contract +// already exists it will be replaced. +func (c *Client) AddContract(ctx context.Context, contract api.ContractMetadata) error { + return c.c.WithContext(ctx).PUT("/contracts", contract) } // AncestorContracts returns any ancestors of a given contract. @@ -128,7 +121,7 @@ func (c *Client) DeleteContractSet(ctx context.Context, set string) (err error) // FormContract forms a contract with a host and adds it to the bus. func (c *Client) FormContract(ctx context.Context, renterAddress types.Address, renterFunds types.Currency, hostKey types.PublicKey, hostIP string, hostCollateral types.Currency, endHeight uint64) (contract api.ContractMetadata, err error) { - err = c.c.WithContext(ctx).POST("/contracts", api.ContractFormRequest{ + err = c.c.WithContext(ctx).POST("/contracts/form", api.ContractFormRequest{ EndHeight: endHeight, HostCollateral: hostCollateral, HostKey: hostKey, diff --git a/bus/routes.go b/bus/routes.go index 0f4efe725..a3f89cbbb 100644 --- a/bus/routes.go +++ b/bus/routes.go @@ -763,7 +763,7 @@ func (b *Bus) contractsHandlerGET(jc jape.Context) { case api.ContractFilterModeActive: case api.ContractFilterModeArchived: default: - jc.Error(fmt.Errorf("invalid filter mode: %v", filterMode), http.StatusBadRequest) + jc.Error(fmt.Errorf("invalid filter mode: '%v'", filterMode), http.StatusBadRequest) return } @@ -1088,36 +1088,23 @@ func (b *Bus) contractIDHandlerGET(jc jape.Context) { } } -func (b *Bus) contractIDHandlerPOST(jc jape.Context) { - // decode parameters - var id types.FileContractID - if jc.DecodeParam("id", &id) != nil { - return - } - var req api.ContractAddRequest - if jc.Decode(&req) != nil { - return - } - - // validate the request - if req.InitialRenterFunds.IsZero() { - http.Error(jc.ResponseWriter, "InitialRenterFunds can not be zero", http.StatusBadRequest) - return - } else if req.Revision.ID() != id { - http.Error(jc.ResponseWriter, "Contract ID missmatch", http.StatusBadRequest) - return - } else if req.Revision.ID() == (types.FileContractID{}) { - http.Error(jc.ResponseWriter, "Contract ID is required", http.StatusBadRequest) - return - } else if req.Revision.HostKey() == (types.PublicKey{}) { - http.Error(jc.ResponseWriter, "HostKey is required", http.StatusBadRequest) +func (b *Bus) contractsHandlerPUT(jc jape.Context) { + // decode request + var c api.ContractMetadata + if jc.Decode(&c) != nil { return } - // add the contract - metadata, err := b.addContract(jc.Request.Context(), req.Revision, req.ContractPrice, req.InitialRenterFunds, req.StartHeight, req.State) - if jc.Check("couldn't add contract", err) == nil { - jc.Encode(metadata) + // upsert the contract + if jc.Check("failed to add contract", b.ms.PutContract(jc.Request.Context(), c)) == nil { + b.broadcastAction(webhooks.Event{ + Module: api.ModuleContract, + Event: api.EventAdd, + Payload: api.EventContractAdd{ + Added: c, + Timestamp: time.Now().UTC(), + }, + }) } } diff --git a/internal/test/e2e/cluster_test.go b/internal/test/e2e/cluster_test.go index 5a424aa3c..8648c716e 100644 --- a/internal/test/e2e/cluster_test.go +++ b/internal/test/e2e/cluster_test.go @@ -1517,7 +1517,7 @@ func TestUnconfirmedContractArchival(t *testing.T) { c := contracts[0] // add a contract to the bus - err = cluster.bs.AddContract(context.Background(), api.ContractMetadata{ + err = cluster.bs.InsertContract(context.Background(), api.ContractMetadata{ ID: types.FileContractID{1}, HostKey: types.PublicKey{1}, StartHeight: cs.BlockHeight, diff --git a/stores/metadata.go b/stores/metadata.go index 2832d9d1e..d1c5ab557 100644 --- a/stores/metadata.go +++ b/stores/metadata.go @@ -108,6 +108,27 @@ func (s *SQLStore) SlabBuffers(ctx context.Context) ([]api.SlabBuffer, error) { return buffers, nil } +func (s *SQLStore) AddRenewal(ctx context.Context, c api.ContractMetadata) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + // fetch renewed contract + renewed, err := tx.Contract(ctx, c.RenewedFrom) + if err != nil { + return err + } + + // insert renewal by updating the renewed contract + err = tx.UpdateContract(ctx, c.RenewedFrom, c) + if err != nil { + return err + } + + // reinsert renewed contract + renewed.ArchivalReason = api.ContractArchivalReasonRenewed + renewed.RenewedTo = c.ID + return tx.InsertContract(ctx, renewed) + }) +} + func (s *SQLStore) AncestorContracts(ctx context.Context, id types.FileContractID, startHeight uint64) (ancestors []api.ContractMetadata, err error) { err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { ancestors, err = tx.AncestorContracts(ctx, id, startHeight) @@ -208,12 +229,18 @@ func (s *SQLStore) ContractSize(ctx context.Context, id types.FileContractID) (c return cs, err } -func (s *SQLStore) AddContract(ctx context.Context, c api.ContractMetadata) error { +func (s *SQLStore) InsertContract(ctx context.Context, c api.ContractMetadata) error { return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.InsertContract(ctx, c) }) } +func (s *SQLStore) PutContract(ctx context.Context, c api.ContractMetadata) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + return tx.PutContract(ctx, c) + }) +} + func (s *SQLStore) UpdateContractSet(ctx context.Context, name string, toAdd, toRemove []types.FileContractID) error { toAddMap := make(map[types.FileContractID]struct{}) for _, fcid := range toAdd { @@ -556,27 +583,6 @@ func (s *SQLStore) RemoveObjects(ctx context.Context, bucket, prefix string) err return nil } -func (s *SQLStore) RenewContract(ctx context.Context, c api.ContractMetadata) error { - return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { - // fetch renewed contract - renewed, err := tx.Contract(ctx, c.RenewedFrom) - if err != nil { - return err - } - - // insert renewal by updating the renewed contract - err = tx.UpdateContract(ctx, c.RenewedFrom, c) - if err != nil { - return err - } - - // reinsert renewed contract - renewed.ArchivalReason = api.ContractArchivalReasonRenewed - renewed.RenewedTo = c.ID - return tx.InsertContract(ctx, renewed) - }) -} - func (s *SQLStore) Slab(ctx context.Context, key object.EncryptionKey) (slab object.Slab, err error) { err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { slab, err = tx.Slab(ctx, key) diff --git a/stores/metadata_test.go b/stores/metadata_test.go index 3c7e93cfc..d4653e300 100644 --- a/stores/metadata_test.go +++ b/stores/metadata_test.go @@ -399,18 +399,20 @@ func TestSQLContractStore(t *testing.T) { Uploads: types.NewCurrency64(6), }, } - if err := ss.AddContract(context.Background(), c); err != nil { + if err := ss.InsertContract(context.Background(), c); err != nil { t.Fatal(err) } - // decorate the host IP - c.HostIP = "address" - // fetch the contract inserted, err := ss.Contract(context.Background(), fcid) if err != nil { t.Fatal(err) - } else if !reflect.DeepEqual(inserted, c) { + } + + // assert it's equal + c.CreatedAt = inserted.CreatedAt + c.HostIP = inserted.HostIP + if !reflect.DeepEqual(inserted, c) { t.Fatal("contract mismatch", cmp.Diff(inserted, c)) } @@ -528,11 +530,7 @@ func TestContractRoots(t *testing.T) { // TestAncestorsContracts verifies that AncestorContracts returns the right // ancestors in the correct order. func TestAncestorsContracts(t *testing.T) { - cfg := defaultTestSQLStoreConfig - cfg.persistent = true - cfg.dir = "/Users/peterjan/testing3" - os.RemoveAll(cfg.dir) - ss := newTestSQLStore(t, cfg) + ss := newTestSQLStore(t, defaultTestSQLStoreConfig) defer ss.Close() hk := types.PublicKey{1, 2, 3} @@ -576,7 +574,7 @@ func TestAncestorsContracts(t *testing.T) { expected.RenewedTo = renewedTo expected.ArchivalReason = api.ContractArchivalReasonRenewed expected.StartHeight = uint64(len(fcids) - 2 - i) - expected.Spending = api.ContractSpending{} + expected.CreatedAt = contracts[i].CreatedAt if !reflect.DeepEqual(contracts[i], expected) { t.Log(cmp.Diff(contracts[i], expected)) t.Fatal("wrong contract", i, contracts[i]) @@ -888,9 +886,11 @@ func TestSQLMetadataStore(t *testing.T) { if !reflect.DeepEqual(slab2, expectedObjSlab2) { t.Fatal("mismatch", cmp.Diff(slab2, expectedObjSlab2)) } + expectedContract1.CreatedAt = contract1.CreatedAt if !reflect.DeepEqual(contract1, expectedContract1) { t.Fatal("mismatch", cmp.Diff(contract1, expectedContract1)) } + expectedContract2.CreatedAt = contract2.CreatedAt if !reflect.DeepEqual(contract2, expectedContract2) { t.Fatal("mismatch", cmp.Diff(contract2, expectedContract2)) } @@ -3447,6 +3447,7 @@ func TestDeleteHostSector(t *testing.T) { t.Fatalf("expected slab id to be %v, got %v", slabID, sectors[0].SlabID) } } + func newTestShards(hk types.PublicKey, fcid types.FileContractID, root types.Hash256) []object.Sector { return []object.Sector{ newTestShard(hk, fcid, root), @@ -4454,3 +4455,93 @@ func TestDirectories(t *testing.T) { t.Fatal("expected 1 dir, got", n) } } + +func TestPutContract(t *testing.T) { + ss := newTestSQLStore(t, defaultTestSQLStoreConfig) + defer ss.Close() + + hk := types.PublicKey{1} + if err := ss.addTestHost(hk); err != nil { + t.Fatal(err) + } + + c := api.ContractMetadata{ + CreatedAt: time.Now(), + ID: types.FileContractID{1}, + HostKey: hk, + + ProofHeight: 2, + RenewedFrom: types.FileContractID{3}, + RevisionHeight: 4, + RevisionNumber: 5, + Size: 6, + StartHeight: 7, + State: api.ContractStateComplete, + WindowStart: 8, + WindowEnd: 9, + + ContractPrice: types.NewCurrency64(10), + InitialRenterFunds: types.NewCurrency64(11), + Spending: api.ContractSpending{ + Deletions: types.NewCurrency64(12), + FundAccount: types.NewCurrency64(13), + SectorRoots: types.NewCurrency64(14), + Uploads: types.NewCurrency64(15), + }, + + ArchivalReason: api.ContractArchivalReasonHostPruned, + RenewedTo: types.FileContractID{16}, + } + if err := ss.PutContract(context.Background(), c); err != nil { + t.Fatal(err) + } + + // insert and assert the returned metadata is equal to the inserted metadata + if contracts, err := ss.Contracts(context.Background(), api.ContractsOpts{FilterMode: api.ContractFilterModeAll}); err != nil { + t.Fatal(err) + } else if len(contracts) != 1 { + t.Fatalf("expected 1 contract, instead got %d", len(contracts)) + } else if contracts[0].CreatedAt = c.CreatedAt; !reflect.DeepEqual(contracts[0], c) { + t.Fatalf("contracts are not equal, diff: %s", cmp.Diff(contracts[0], c)) + } + + u := api.ContractMetadata{ + CreatedAt: time.Now(), + ID: types.FileContractID{1}, + HostKey: hk, + + ProofHeight: 17, + RenewedFrom: types.FileContractID{18}, + RevisionHeight: 19, + RevisionNumber: 20, + Size: 21, + StartHeight: 22, + State: api.ContractStateFailed, + WindowStart: 23, + WindowEnd: 24, + + ContractPrice: types.NewCurrency64(25), + InitialRenterFunds: types.NewCurrency64(26), + Spending: api.ContractSpending{ + Deletions: types.NewCurrency64(27), + FundAccount: types.NewCurrency64(28), + SectorRoots: types.NewCurrency64(29), + Uploads: types.NewCurrency64(30), + }, + + ArchivalReason: api.ContractArchivalReasonRemoved, + RenewedTo: types.FileContractID{31}, + } + if err := ss.PutContract(context.Background(), u); err != nil { + t.Fatal(err) + } + + // update and assert the returned metadata is equal to the metadata + if contracts, err := ss.Contracts(context.Background(), api.ContractsOpts{FilterMode: api.ContractFilterModeAll}); err != nil { + t.Fatal(err) + } else if len(contracts) != 1 { + t.Fatalf("expected 1 contract, instead got %d", len(contracts)) + } else if contracts[0].CreatedAt = u.CreatedAt; !reflect.DeepEqual(contracts[0], u) { + t.Fatalf("contracts are not equal, diff: %s", cmp.Diff(contracts[0], u)) + } +} diff --git a/stores/sql/database.go b/stores/sql/database.go index c4b9fa4a1..c23089120 100644 --- a/stores/sql/database.go +++ b/stores/sql/database.go @@ -93,8 +93,9 @@ type ( // duplicates but can contain gaps. CompleteMultipartUpload(ctx context.Context, bucket, key, uploadID string, parts []api.MultipartCompletedPart, opts api.CompleteMultipartOptions) (string, error) - // Contract returns the metadata of the contract with the given ID or - // ErrContractNotFound. + // Contract returns the metadata of the contract with the given id, if + // the requested contract does not exist, or if it is archived, + // ErrContractNotFound is returned. Contract(ctx context.Context, id types.FileContractID) (cm api.ContractMetadata, err error) // ContractRoots returns the roots of the contract with the given ID. @@ -253,6 +254,10 @@ type ( // or slab buffer. PruneSlabs(ctx context.Context, limit int64) (int64, error) + // PutContract inserts the contract if it does not exist, otherwise it + // will overwrite all fields. + PutContract(ctx context.Context, c api.ContractMetadata) error + // RecordContractSpending records new spending for a contract RecordContractSpending(ctx context.Context, fcid types.FileContractID, revisionNumber, size uint64, newSpending api.ContractSpending) error diff --git a/stores/sql/mysql/main.go b/stores/sql/mysql/main.go index f4fd0d5ff..1c650ac8c 100644 --- a/stores/sql/mysql/main.go +++ b/stores/sql/mysql/main.go @@ -670,6 +670,62 @@ func (tx *MainDatabaseTx) PruneSlabs(ctx context.Context, limit int64) (int64, e return res.RowsAffected() } +func (tx *MainDatabaseTx) PutContract(ctx context.Context, c api.ContractMetadata) error { + // assert decorated fields are unset + if c.HostIP != "" { + return errors.New("host IP should not be set") + } else if c.ContractSets != nil { + return errors.New("contract sets should not be set") + } else if c.SiamuxAddr != "" { + return errors.New("siamux address should not be set") + } + + // validate metadata + var state ssql.ContractState + if err := state.LoadString(c.State); err != nil { + return err + } else if c.ID == (types.FileContractID{}) { + return errors.New("contract id is required") + } else if c.HostKey == (types.PublicKey{}) { + return errors.New("host key is required") + } + + // fetch host id + var hostID int64 + err := tx.QueryRow(ctx, `SELECT id FROM hosts WHERE public_key = ?`, ssql.PublicKey(c.HostKey)).Scan(&hostID) + if errors.Is(err, dsql.ErrNoRows) { + return api.ErrHostNotFound + } + + // set created at if it's not set + if c.CreatedAt.IsZero() { + c.CreatedAt = time.Now().UTC() + } + + // update contract + _, err = tx.Exec(ctx, ` +INSERT INTO contracts ( + created_at, fcid, host_id, host_key, + archival_reason, proof_height, renewed_from, renewed_to, revision_height, revision_number, size, start_height, state, window_start, window_end, + contract_price, initial_renter_funds, + delete_spending, fund_account_spending, sector_roots_spending, upload_spending +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) +ON DUPLICATE KEY UPDATE + created_at = VALUES(created_at), fcid = VALUES(fcid), host_id = VALUES(host_id), host_key = VALUES(host_key), + archival_reason = VALUES(archival_reason), proof_height = VALUES(proof_height), renewed_from = VALUES(renewed_from), renewed_to = VALUES(renewed_to), revision_height = VALUES(revision_height), revision_number = VALUES(revision_number), size = VALUES(size), start_height = VALUES(start_height), state = VALUES(state), window_start = VALUES(window_start), window_end = VALUES(window_end), + contract_price = VALUES(contract_price), initial_renter_funds = VALUES(initial_renter_funds), + delete_spending = VALUES(delete_spending), fund_account_spending = VALUES(fund_account_spending), sector_roots_spending = VALUES(sector_roots_spending), upload_spending = VALUES(upload_spending)`, + c.CreatedAt, ssql.FileContractID(c.ID), hostID, ssql.PublicKey(c.HostKey), + ssql.NullableString(c.ArchivalReason), c.ProofHeight, ssql.FileContractID(c.RenewedFrom), ssql.FileContractID(c.RenewedTo), c.RevisionHeight, c.RevisionNumber, c.Size, c.StartHeight, state, c.WindowStart, c.WindowEnd, + ssql.Currency(c.ContractPrice), ssql.Currency(c.InitialRenterFunds), + ssql.Currency(c.Spending.Deletions), ssql.Currency(c.Spending.FundAccount), ssql.Currency(c.Spending.SectorRoots), ssql.Currency(c.Spending.Uploads), + ) + if err != nil { + return fmt.Errorf("failed to update contract: %w", err) + } + return nil +} + func (tx *MainDatabaseTx) RecordContractSpending(ctx context.Context, fcid types.FileContractID, revisionNumber, size uint64, newSpending api.ContractSpending) error { return ssql.RecordContractSpending(ctx, tx, fcid, revisionNumber, size, newSpending) } diff --git a/stores/sql/rows.go b/stores/sql/rows.go index 6bc968af2..364e61dd3 100644 --- a/stores/sql/rows.go +++ b/stores/sql/rows.go @@ -79,9 +79,10 @@ func (r *ContractRow) ContractMetadata() api.ContractMetadata { } return api.ContractMetadata{ - ID: types.FileContractID(r.FCID), - HostIP: r.NetAddress, - HostKey: types.PublicKey(r.HostKey), + CreatedAt: r.CreatedAt, + ID: types.FileContractID(r.FCID), + HostIP: r.NetAddress, + HostKey: types.PublicKey(r.HostKey), ContractPrice: types.Currency(r.ContractPrice), InitialRenterFunds: types.Currency(r.InitialRenterFunds), diff --git a/stores/sql/sqlite/main.go b/stores/sql/sqlite/main.go index ba80e0ef4..1f683ef13 100644 --- a/stores/sql/sqlite/main.go +++ b/stores/sql/sqlite/main.go @@ -680,6 +680,53 @@ func (tx *MainDatabaseTx) PruneSlabs(ctx context.Context, limit int64) (int64, e return res.RowsAffected() } +func (tx *MainDatabaseTx) PutContract(ctx context.Context, c api.ContractMetadata) error { + // validate metadata + var state ssql.ContractState + if err := state.LoadString(c.State); err != nil { + return err + } else if c.ID == (types.FileContractID{}) { + return errors.New("contract id is required") + } else if c.HostKey == (types.PublicKey{}) { + return errors.New("host key is required") + } + + // fetch host id + var hostID int64 + err := tx.QueryRow(ctx, `SELECT id FROM hosts WHERE public_key = ?`, ssql.PublicKey(c.HostKey)).Scan(&hostID) + if errors.Is(err, dsql.ErrNoRows) { + return api.ErrHostNotFound + } + + // set created at if it's not set + if c.CreatedAt.IsZero() { + c.CreatedAt = time.Now() + } + + // update contract + _, err = tx.Exec(ctx, ` +INSERT INTO contracts ( + created_at, fcid, host_id, host_key, + archival_reason, proof_height, renewed_from, renewed_to, revision_height, revision_number, size, start_height, state, window_start, window_end, + contract_price, initial_renter_funds, + delete_spending, fund_account_spending, sector_roots_spending, upload_spending +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) +ON CONFLICT(fcid) DO UPDATE SET + created_at = EXCLUDED.created_at, fcid = EXCLUDED.fcid, host_id = EXCLUDED.host_id, host_key = EXCLUDED.host_key, + archival_reason = EXCLUDED.archival_reason, proof_height = EXCLUDED.proof_height, renewed_from = EXCLUDED.renewed_from, renewed_to = EXCLUDED.renewed_to, revision_height = EXCLUDED.revision_height, revision_number = EXCLUDED.revision_number, size = EXCLUDED.size, start_height = EXCLUDED.start_height, state = EXCLUDED.state, window_start = EXCLUDED.window_start, window_end = EXCLUDED.window_end, + contract_price = EXCLUDED.contract_price, initial_renter_funds = EXCLUDED.initial_renter_funds, + delete_spending = EXCLUDED.delete_spending, fund_account_spending = EXCLUDED.fund_account_spending, sector_roots_spending = EXCLUDED.sector_roots_spending, upload_spending = EXCLUDED.upload_spending`, + c.CreatedAt, ssql.FileContractID(c.ID), hostID, ssql.PublicKey(c.HostKey), + ssql.NullableString(c.ArchivalReason), c.ProofHeight, ssql.FileContractID(c.RenewedFrom), ssql.FileContractID(c.RenewedTo), c.RevisionHeight, c.RevisionNumber, c.Size, c.StartHeight, state, c.WindowStart, c.WindowEnd, + ssql.Currency(c.ContractPrice), ssql.Currency(c.InitialRenterFunds), + ssql.Currency(c.Spending.Deletions), ssql.Currency(c.Spending.FundAccount), ssql.Currency(c.Spending.SectorRoots), ssql.Currency(c.Spending.Uploads), + ) + if err != nil { + return fmt.Errorf("failed to update contract: %w", err) + } + return nil +} + func (tx *MainDatabaseTx) RecordContractSpending(ctx context.Context, fcid types.FileContractID, revisionNumber, size uint64, newSpending api.ContractSpending) error { return ssql.RecordContractSpending(ctx, tx, fcid, revisionNumber, size, newSpending) } diff --git a/stores/sql_test.go b/stores/sql_test.go index 5dcbe1236..ae130bedd 100644 --- a/stores/sql_test.go +++ b/stores/sql_test.go @@ -314,7 +314,7 @@ func (s *testSQLStore) addTestContracts(keys []types.PublicKey) (fcids []types.F } func (s *SQLStore) addTestContract(fcid types.FileContractID, hk types.PublicKey) (api.ContractMetadata, error) { - if err := s.AddContract(context.Background(), newTestContract(fcid, hk)); err != nil { + if err := s.InsertContract(context.Background(), newTestContract(fcid, hk)); err != nil { return api.ContractMetadata{}, err } return s.Contract(context.Background(), fcid) @@ -338,5 +338,5 @@ func (s *testSQLStore) renewTestContract(hk types.PublicKey, renewedFrom, renewe renewal := newTestContract(renewedTo, hk) renewal.StartHeight = startHeight renewal.RenewedFrom = renewedFrom - return s.RenewContract(context.Background(), renewal) + return s.AddRenewal(context.Background(), renewal) }