diff --git a/client/download_test.go b/client/download_test.go index 8dda2d6..bea4775 100644 --- a/client/download_test.go +++ b/client/download_test.go @@ -1,6 +1,11 @@ package client import ( + "archive/tar" + "bytes" + "compress/gzip" + "context" + "io" "net/http" "net/http/httptest" "strings" @@ -10,8 +15,7 @@ import ( "github.com/stretchr/testify/require" ) -// TestRead checks the database download functionality. -func TestRead(t *testing.T) { +func TestDownload(t *testing.T) { edition := metadata{ EditionID: "edition-1", Date: "2024-02-02", @@ -40,7 +44,7 @@ func TestRead(t *testing.T) { description string preserveFileTime bool server func(t *testing.T) *httptest.Server - checkResult func(t *testing.T, resp *ReadResult, err error) + checkResult func(t *testing.T, res DownloadResponse, err error) }{ { description: "successful download", @@ -82,16 +86,13 @@ func TestRead(t *testing.T) { return server }, - checkResult: func(t *testing.T, resp *ReadResult, err error) { + checkResult: func(t *testing.T, res DownloadResponse, err error) { require.NoError(t, err) - c, rerr := io.ReadAll(resp.reader) + c, rerr := io.ReadAll(res.Reader) require.NoError(t, rerr) require.Equal(t, dbContent, string(c)) - require.Equal(t, edition.EditionID, resp.EditionID) - require.Equal(t, edition.MD5, resp.OldHash) - require.Equal(t, "618dd27a10de24809ec160d6807f363f", resp.NewHash) - - require.Equal(t, lastModified, resp.ModifiedAt) + require.Equal(t, "618dd27a10de24809ec160d6807f363f", res.MD5) + require.Equal(t, lastModified, res.LastModified) }, }, { @@ -108,8 +109,7 @@ func TestRead(t *testing.T) { })) return server }, - checkResult: func(t *testing.T, resp *ReadResult, err error) { - require.Nil(t, resp) + checkResult: func(t *testing.T, _ DownloadResponse, err error) { require.Error(t, err) require.Regexp(t, "^unexpected HTTP status code", err.Error()) }, @@ -132,8 +132,7 @@ func TestRead(t *testing.T) { })) return server }, - checkResult: func(t *testing.T, resp *ReadResult, err error) { - require.Nil(t, resp) + checkResult: func(t *testing.T, _ DownloadResponse, err error) { require.Error(t, err) require.Regexp(t, "^encountered an error creating GZIP reader", err.Error()) }, @@ -166,8 +165,7 @@ func TestRead(t *testing.T) { return server }, - checkResult: func(t *testing.T, resp *ReadResult, err error) { - require.Nil(t, resp) + checkResult: func(t *testing.T, _ DownloadResponse, err error) { require.Error(t, err) require.Regexp(t, "^tar archive does not contain an mmdb file", err.Error()) }, @@ -211,8 +209,7 @@ func TestRead(t *testing.T) { return server }, - checkResult: func(t *testing.T, resp *ReadResult, err error) { - require.Nil(t, resp) + checkResult: func(t *testing.T, _ DownloadResponse, err error) { require.Error(t, err) require.Regexp(t, "^tar archive does not contain an mmdb file", err.Error()) }, @@ -220,21 +217,24 @@ func TestRead(t *testing.T) { } ctx := context.Background() + + accountID := 10 + licenseKey := "license" + for _, test := range tests { t.Run(test.description, func(t *testing.T) { server := test.server(t) defer server.Close() - r := NewHTTPReader( - server.URL, // fixed, as the server is mocked above. - 10, // fixed, as it's not valuable for the purpose of the test. - "license", // fixed, as it's not valuable for the purpose of the test. - false, // verbose - http.DefaultClient, + c, err := New( + accountID, + licenseKey, + WithEndpoint(server.URL), ) + require.NoError(t, err) - reader, err := r.get(ctx, edition.EditionID, edition.MD5) - test.checkResult(t, reader, err) + res, err := c.Download(ctx, edition.EditionID, edition.MD5) + test.checkResult(t, res, err) }) } } diff --git a/client/metadata_test.go b/client/metadata_test.go index b245801..1f2f2d0 100644 --- a/client/metadata_test.go +++ b/client/metadata_test.go @@ -64,20 +64,22 @@ func TestGetMetadata(t *testing.T) { ctx := context.Background() + accountID := 10 + licenseKey := "license" + for _, test := range tests { t.Run(test.description, func(t *testing.T) { server := test.server(t) defer server.Close() - r := NewHTTPReader( - server.URL, // fixed, as the server is mocked above. - 10, // fixed, as it's not valuable for the purpose of the test. - "license", // fixed, as it's not valuable for the purpose of the test. - false, // verbose - http.DefaultClient, + c, err := New( + accountID, + licenseKey, + WithEndpoint(server.URL), ) + require.NoError(t, err) - result, err := r.getMetadata(ctx, "edition-1") + result, err := c.getMetadata(ctx, "edition-1") test.checkResult(t, result, err) }) } diff --git a/internal/geoipupdate/database/local_file_writer_test.go b/internal/geoipupdate/database/local_file_writer_test.go index da4a357..c80797d 100644 --- a/internal/geoipupdate/database/local_file_writer_test.go +++ b/internal/geoipupdate/database/local_file_writer_test.go @@ -20,72 +20,67 @@ func TestLocalFileWriterWrite(t *testing.T) { checkErr func(require.TestingT, error, ...interface{}) preserveFileTime bool //nolint:revive // support older versions - checkTime func(require.TestingT, interface{}, interface{}, ...interface{}) - result *ReadResult + checkTime func(require.TestingT, interface{}, interface{}, ...interface{}) + editionID string + reader io.ReadCloser + newMD5 string + lastModified time.Time }{ { description: "success", checkErr: require.NoError, preserveFileTime: true, checkTime: require.Equal, - result: &ReadResult{ - reader: io.NopCloser(strings.NewReader("database content")), - EditionID: "GeoIP2-City", - OldHash: "", - NewHash: "cfa36ddc8279b5483a5aa25e9a6151f4", - ModifiedAt: testTime, - }, + editionID: "GeoIP2-City", + reader: io.NopCloser(strings.NewReader("database content")), + newMD5: "cfa36ddc8279b5483a5aa25e9a6151f4", + lastModified: testTime, }, { description: "hash does not match", checkErr: require.Error, preserveFileTime: true, checkTime: require.Equal, - result: &ReadResult{ - reader: io.NopCloser(strings.NewReader("database content")), - EditionID: "GeoIP2-City", - OldHash: "", - NewHash: "badhash", - ModifiedAt: testTime, - }, + editionID: "GeoIP2-City", + reader: io.NopCloser(strings.NewReader("database content")), + newMD5: "badhash", + lastModified: testTime, }, { description: "hash case does not matter", checkErr: require.NoError, preserveFileTime: true, checkTime: require.Equal, - result: &ReadResult{ - reader: io.NopCloser(strings.NewReader("database content")), - EditionID: "GeoIP2-City", - OldHash: "", - NewHash: "cfa36ddc8279b5483a5aa25e9a6151f4", - ModifiedAt: testTime, - }, + editionID: "GeoIP2-City", + reader: io.NopCloser(strings.NewReader("database content")), + newMD5: "cfa36ddc8279b5483a5aa25e9a6151f4", + lastModified: testTime, }, { description: "do not preserve file modification time", checkErr: require.NoError, preserveFileTime: false, checkTime: require.NotEqual, - result: &ReadResult{ - reader: io.NopCloser(strings.NewReader("database content")), - EditionID: "GeoIP2-City", - OldHash: "", - NewHash: "CFA36DDC8279B5483A5AA25E9A6151F4", - ModifiedAt: testTime, - }, + editionID: "GeoIP2-City", + reader: io.NopCloser(strings.NewReader("database content")), + newMD5: "CFA36DDC8279B5483A5AA25E9A6151F4", + lastModified: testTime, }, } for _, test := range tests { t.Run(test.description, func(t *testing.T) { tempDir := t.TempDir() - defer test.result.reader.Close() fw, err := NewLocalFileWriter(tempDir, test.preserveFileTime, false) require.NoError(t, err) - err = fw.Write(test.result) + err = fw.Write( + test.editionID, + test.reader, + test.newMD5, + test.lastModified, + ) test.checkErr(t, err) if err == nil { - database, err := os.Stat(fw.getFilePath(test.result.EditionID)) + database, err := os.Stat(fw.getFilePath(test.editionID)) require.NoError(t, err) test.checkTime(t, database.ModTime().UTC(), testTime) @@ -96,28 +91,23 @@ func TestLocalFileWriterWrite(t *testing.T) { // TestLocalFileWriterGetHash tests functionality of the LocalFileWriter.GetHash method. func TestLocalFileWriterGetHash(t *testing.T) { - result := &ReadResult{ - reader: io.NopCloser(strings.NewReader("database content")), - EditionID: "GeoIP2-City", - OldHash: "", - NewHash: "cfa36ddc8279b5483a5aa25e9a6151f4", - ModifiedAt: time.Time{}, - } + editionID := "GeoIP2-City" + reader := io.NopCloser(strings.NewReader("database content")) + newMD5 := "cfa36ddc8279b5483a5aa25e9a6151f4" + lastModified := time.Time{} tempDir := t.TempDir() - defer result.reader.Close() - fw, err := NewLocalFileWriter(tempDir, false, false) require.NoError(t, err) - err = fw.Write(result) + err = fw.Write(editionID, reader, newMD5, lastModified) require.NoError(t, err) // returns the correct hash for an existing database. - hash, err := fw.GetHash(result.EditionID) + hash, err := fw.GetHash(editionID) require.NoError(t, err) - require.Equal(t, hash, result.NewHash) + require.Equal(t, hash, newMD5) // returns a zero hash for a non existing edition. hash, err = fw.GetHash("NewEdition") diff --git a/internal/geoipupdate/geoip_updater_test.go b/internal/geoipupdate/geoip_updater_test.go index cdb67fe..0808ead 100644 --- a/internal/geoipupdate/geoip_updater_test.go +++ b/internal/geoipupdate/geoip_updater_test.go @@ -7,11 +7,13 @@ import ( "context" "encoding/json" "errors" + "io" "log" "net/http" "net/http/httptest" "os" "path/filepath" + "strings" "testing" "time" @@ -19,6 +21,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/net/http2" + "github.com/maxmind/geoipupdate/v6/client" "github.com/maxmind/geoipupdate/v6/internal" "github.com/maxmind/geoipupdate/v6/internal/geoipupdate/database" ) @@ -28,17 +31,18 @@ import ( func TestUpdaterOutput(t *testing.T) { now := time.Now().Truncate(time.Second).In(time.UTC) testTime := time.Date(2023, 4, 27, 12, 4, 48, 0, time.UTC) - databases := []database.ReadResult{ + outputs := []client.DownloadResponse{ { - EditionID: "GeoLite2-City", - OldHash: "A", - NewHash: "B", - ModifiedAt: testTime, - }, { - EditionID: "GeoIP2-Country", - OldHash: "C", - NewHash: "D", - ModifiedAt: testTime, + LastModified: testTime, + MD5: "B", + Reader: io.NopCloser(strings.NewReader("")), + UpdateAvailable: true, + }, + { + LastModified: testTime, + MD5: "D", + Reader: io.NopCloser(strings.NewReader("")), + UpdateAvailable: true, }, } @@ -56,10 +60,17 @@ func TestUpdaterOutput(t *testing.T) { // create a fake Updater with a mocked database reader and writer. u := &Updater{ - config: config, - reader: &mockReader{i: 0, result: databases}, - output: log.New(logOutput, "", 0), - writer: &mockWriter{}, + config: config, + output: log.New(logOutput, "", 0), + updateClient: &mockUpdateClient{i: 0, outputs: outputs}, + writer: &mockWriter{ + md5s: map[string]string{ + // These are the "MD5s" that we currently have before running an + // update. + "GeoLite2-City": "A", + "GeoLite2-Country": "C", + }, + }, } err := u.Run(context.Background()) @@ -69,13 +80,29 @@ func TestUpdaterOutput(t *testing.T) { var outputDatabases []database.ReadResult err = json.Unmarshal(logOutput.Bytes(), &outputDatabases) require.NoError(t, err) - require.Equal(t, len(outputDatabases), len(databases)) - for i := 0; i < len(databases); i++ { - require.Equal(t, databases[i].EditionID, outputDatabases[i].EditionID) - require.Equal(t, databases[i].OldHash, outputDatabases[i].OldHash) - require.Equal(t, databases[i].NewHash, outputDatabases[i].NewHash) - require.Equal(t, databases[i].ModifiedAt, outputDatabases[i].ModifiedAt) + wantDatabases := []database.ReadResult{ + { + EditionID: "GeoLite2-City", + OldHash: "A", + NewHash: "B", + ModifiedAt: testTime, + }, + { + EditionID: "GeoLite2-Country", + OldHash: "C", + NewHash: "D", + ModifiedAt: testTime, + }, + } + + require.Equal(t, len(wantDatabases), len(outputDatabases)) + + for i := 0; i < len(wantDatabases); i++ { + require.Equal(t, wantDatabases[i].EditionID, outputDatabases[i].EditionID) + require.Equal(t, wantDatabases[i].OldHash, outputDatabases[i].OldHash) + require.Equal(t, wantDatabases[i].NewHash, outputDatabases[i].NewHash) + require.Equal(t, wantDatabases[i].ModifiedAt, outputDatabases[i].ModifiedAt) // comparing time wasn't supported with require in older go versions. if !afterOrEqual(outputDatabases[i].CheckedAt, now) { t.Errorf("database %s was not updated", outputDatabases[i].EditionID) @@ -84,13 +111,13 @@ func TestUpdaterOutput(t *testing.T) { // Test with a write error. - u.reader.(*mockReader).i = 0 + u.updateClient.(*mockUpdateClient).i = 0 streamErr := http2.StreamError{ Code: http2.ErrCodeInternal, } u.writer = &mockWriter{ - WriteFunc: func(_ *database.ReadResult) error { + writeFunc: func(_ string, _ io.ReadCloser, _ string, _ time.Time) error { return streamErr }, } @@ -159,8 +186,10 @@ func TestRetryWhenWriting(t *testing.T) { defer sv.Close() config := &Config{ + AccountID: 10, URL: sv.URL, EditionIDs: []string{"foo-db-name"}, + LicenseKey: "foo", LockFile: filepath.Join(tempDir, ".geoipupdate.lock"), Output: true, Parallelism: 1, @@ -170,6 +199,13 @@ func TestRetryWhenWriting(t *testing.T) { logOutput := &bytes.Buffer{} + updateClient, err := client.New( + config.AccountID, + config.LicenseKey, + client.WithEndpoint(config.URL), + ) + require.NoError(t, err) + writer, err := database.NewLocalFileWriter( config.DatabaseDirectory, config.PreserveFileTimes, @@ -178,16 +214,10 @@ func TestRetryWhenWriting(t *testing.T) { require.NoError(t, err) u := &Updater{ - config: config, - reader: database.NewHTTPReader( - config.URL, - config.AccountID, - config.LicenseKey, - config.Verbose, - http.DefaultClient, - ), - output: log.New(logOutput, "", 0), - writer: writer, + config: config, + output: log.New(logOutput, "", 0), + updateClient: updateClient, + writer: writer, } ctx := context.Background() @@ -197,7 +227,7 @@ func TestRetryWhenWriting(t *testing.T) { _, err = u.downloadEdition( ctx, "foo-db-name", - u.reader, + u.updateClient, u.writer, ) @@ -211,32 +241,45 @@ func TestRetryWhenWriting(t *testing.T) { assert.Empty(t, logOutput.String()) } -type mockReader struct { - i int - result []database.ReadResult +type mockUpdateClient struct { + i int + outputs []client.DownloadResponse } -func (mr *mockReader) Read(_ context.Context, _, _ string) (*database.ReadResult, error) { - if mr.i >= len(mr.result) { - return nil, errors.New("out of bounds") +func (m *mockUpdateClient) Download( + _ context.Context, + _, + _ string, +) (client.DownloadResponse, error) { + if m.i >= len(m.outputs) { + return client.DownloadResponse{}, errors.New("out of bounds") } - res := mr.result[mr.i] - mr.i++ - return &res, nil + res := m.outputs[m.i] + m.i++ + return res, nil } type mockWriter struct { - WriteFunc func(*database.ReadResult) error + md5s map[string]string + writeFunc func(string, io.ReadCloser, string, time.Time) error } -func (w *mockWriter) Write(r *database.ReadResult) error { - if w.WriteFunc != nil { - return w.WriteFunc(r) +func (w *mockWriter) Write( + editionID string, + reader io.ReadCloser, + md5 string, + lastModified time.Time, +) error { + if w.writeFunc != nil { + return w.writeFunc(editionID, reader, md5, lastModified) } return nil } -func (w mockWriter) GetHash(_ string) (string, error) { return "", nil } + +func (w mockWriter) GetHash(editionID string) (string, error) { + return w.md5s[editionID], nil +} func afterOrEqual(t1, t2 time.Time) bool { return t1.After(t2) || t1.Equal(t2)