diff --git a/.github/workflows/pr-checks.yaml b/.github/workflows/pr-checks.yaml index 00f22afd..b2074c65 100644 --- a/.github/workflows/pr-checks.yaml +++ b/.github/workflows/pr-checks.yaml @@ -71,4 +71,4 @@ jobs: go-version-file: go.mod - name: Run tests - run: go test -v -race $(go list ./... | grep -v 'tests/integration' | grep -v 'client/indexd/tests') + run: go test -v -race $( go list ./... | grep -v tests/integration ) diff --git a/client/anvil/anvil_client.go b/client/anvil/anvil_client.go index 18d5501a..387bdaba 100644 --- a/client/anvil/anvil_client.go +++ b/client/anvil/anvil_client.go @@ -10,8 +10,8 @@ import ( "time" "github.com/bytedance/sonic" - drs "github.com/calypr/data-client/indexd/drs" - hash "github.com/calypr/data-client/indexd/hash" + drs "github.com/calypr/data-client/drs" + hash "github.com/calypr/data-client/hash" "golang.org/x/oauth2/google" ) diff --git a/client/indexd/add_url.go b/client/indexd/add_url.go index e783aa5b..fdf3261d 100644 --- a/client/indexd/add_url.go +++ b/client/indexd/add_url.go @@ -12,17 +12,17 @@ import ( awsConfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/calypr/data-client/drs" "github.com/calypr/data-client/fence" + "github.com/calypr/data-client/hash" "github.com/calypr/data-client/indexd" - "github.com/calypr/data-client/indexd/drs" - "github.com/calypr/data-client/indexd/hash" + "github.com/calypr/data-client/s3utils" + "github.com/calypr/git-drs/cloud" "github.com/calypr/git-drs/common" "github.com/calypr/git-drs/drslog" "github.com/calypr/git-drs/drsmap" "github.com/calypr/git-drs/lfs" "github.com/calypr/git-drs/messages" - "github.com/calypr/git-drs/s3_utils" - "github.com/calypr/git-drs/utils" ) // getBucketDetails fetches bucket details from Gen3 using data-client. 
@@ -32,7 +32,7 @@ func (inc *GitDrsIdxdClient) getBucketDetails(ctx context.Context, bucket string // FetchS3MetadataWithBucketDetails fetches S3 metadata given bucket details. func FetchS3MetadataWithBucketDetails(ctx context.Context, s3URL, awsAccessKey, awsSecretKey, region, endpoint string, bucketDetails *fence.S3Bucket, s3Client *s3.Client, logger *slog.Logger) (int64, string, error) { - bucket, key, err := utils.ParseS3URL(s3URL) + bucket, key, err := cloud.ParseS3URL(s3URL) if err != nil { return 0, "", fmt.Errorf("failed to parse S3 URL: %w", err) } @@ -101,7 +101,7 @@ func FetchS3MetadataWithBucketDetails(ctx context.Context, s3URL, awsAccessKey, } func (inc *GitDrsIdxdClient) fetchS3Metadata(ctx context.Context, s3URL, awsAccessKey, awsSecretKey, region, endpoint string, s3Client *s3.Client, httpClient *http.Client, logger *slog.Logger) (int64, string, error) { - bucket, _, err := utils.ParseS3URL(s3URL) + bucket, _, err := cloud.ParseS3URL(s3URL) if err != nil { return 0, "", fmt.Errorf("failed to parse S3 URL: %w", err) } @@ -145,7 +145,7 @@ func (inc *GitDrsIdxdClient) upsertIndexdRecord(ctx context.Context, url string, // If no record exists, create one logger.Debug("creating new record") - _, relPath, _ := utils.ParseS3URL(url) + _, relPath, _ := cloud.ParseS3URL(url) drsObj, err := drs.BuildDrsObj(relPath, sha256, fileSize, uuid, inc.Config.BucketName, projectId) if err != nil { @@ -157,11 +157,11 @@ func (inc *GitDrsIdxdClient) upsertIndexdRecord(ctx context.Context, url string, return inc.RegisterRecord(ctx, drsObj) } -func (inc *GitDrsIdxdClient) AddURL(s3URL, sha256, awsAccessKey, awsSecretKey, regionFlag, endpointFlag string, opts ...s3_utils.AddURLOption) (s3_utils.S3Meta, error) { +func (inc *GitDrsIdxdClient) AddURL(s3URL, sha256, awsAccessKey, awsSecretKey, regionFlag, endpointFlag string, opts ...cloud.AddURLOption) (s3utils.S3Meta, error) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - cfg := 
&s3_utils.AddURLConfig{} + cfg := &cloud.AddURLConfig{} for _, opt := range opts { opt(cfg) } @@ -170,27 +170,27 @@ func (inc *GitDrsIdxdClient) AddURL(s3URL, sha256, awsAccessKey, awsSecretKey, r inc.Logger = drslog.NewNoOpLogger() } - if err := s3_utils.ValidateInputs(s3URL, sha256); err != nil { - return s3_utils.S3Meta{}, err + if err := s3utils.ValidateInputs(s3URL, sha256); err != nil { + return s3utils.S3Meta{}, err } - _, relPath, err := utils.ParseS3URL(s3URL) + _, relPath, err := cloud.ParseS3URL(s3URL) if err != nil { - return s3_utils.S3Meta{}, fmt.Errorf("failed to parse S3 URL: %w", err) + return s3utils.S3Meta{}, fmt.Errorf("failed to parse S3 URL: %w", err) } - isLFS, err := lfs.IsLFSTracked(".gitattributes", relPath) + isLFS, err := lfs.IsLFSTracked(relPath) if err != nil { - return s3_utils.S3Meta{}, fmt.Errorf("unable to determine if file is tracked by LFS: %w", err) + return s3utils.S3Meta{}, fmt.Errorf("unable to determine if file is tracked by LFS: %w", err) } if !isLFS { - return s3_utils.S3Meta{}, fmt.Errorf("file is not tracked by LFS") + return s3utils.S3Meta{}, fmt.Errorf("file is not tracked by LFS") } inc.Logger.Debug("Fetching S3 metadata...") fileSize, modifiedDate, err := inc.fetchS3Metadata(ctx, s3URL, awsAccessKey, awsSecretKey, regionFlag, endpointFlag, cfg.S3Client, cfg.HttpClient, inc.Logger) if err != nil { - return s3_utils.S3Meta{}, fmt.Errorf("failed to fetch S3 metadata: %w", err) + return s3utils.S3Meta{}, fmt.Errorf("failed to fetch S3 metadata: %w", err) } inc.Logger.Debug(fmt.Sprintf("Fetched S3 metadata successfully: %d bytes, modified: %s", fileSize, modifiedDate)) @@ -198,20 +198,20 @@ func (inc *GitDrsIdxdClient) AddURL(s3URL, sha256, awsAccessKey, awsSecretKey, r inc.Logger.Debug("Processing indexd record...") drsObj, err := inc.upsertIndexdRecord(ctx, s3URL, sha256, fileSize, inc.Logger) if err != nil { - return s3_utils.S3Meta{}, fmt.Errorf("failed to create indexd record: %w", err) + return s3utils.S3Meta{}, 
fmt.Errorf("failed to create indexd record: %w", err) } drsObjPath, err := drsmap.GetObjectPath(common.DRS_OBJS_PATH, drsObj.Checksums.SHA256) if err != nil { - return s3_utils.S3Meta{}, err + return s3utils.S3Meta{}, err } if err := drsmap.WriteDrsObj(drsObj, sha256, drsObjPath); err != nil { - return s3_utils.S3Meta{}, err + return s3utils.S3Meta{}, err } inc.Logger.Debug("Indexd updated") - return s3_utils.S3Meta{ + return s3utils.S3Meta{ Size: fileSize, LastModified: modifiedDate, }, nil diff --git a/client/indexd/client.go b/client/indexd/client.go index d342fe7c..a747649a 100644 --- a/client/indexd/client.go +++ b/client/indexd/client.go @@ -8,9 +8,9 @@ import ( "github.com/calypr/data-client/common" "github.com/calypr/data-client/conf" + "github.com/calypr/data-client/drs" "github.com/calypr/data-client/g3client" - "github.com/calypr/data-client/indexd/drs" - "github.com/calypr/data-client/indexd/hash" + "github.com/calypr/data-client/hash" "github.com/calypr/data-client/logs" "github.com/calypr/git-drs/client" "github.com/calypr/git-drs/drsmap" diff --git a/client/indexd/client_test.go b/client/indexd/client_test.go index 81f34537..77b5ac34 100644 --- a/client/indexd/client_test.go +++ b/client/indexd/client_test.go @@ -15,10 +15,10 @@ import ( "github.com/bytedance/sonic/encoder" "github.com/calypr/data-client/common" "github.com/calypr/data-client/conf" + "github.com/calypr/data-client/drs" "github.com/calypr/data-client/g3client" + "github.com/calypr/data-client/hash" "github.com/calypr/data-client/indexd" - "github.com/calypr/data-client/indexd/drs" - "github.com/calypr/data-client/indexd/hash" "github.com/calypr/data-client/logs" ) diff --git a/client/indexd/register.go b/client/indexd/register.go index 0e0f1602..3fd5f62a 100644 --- a/client/indexd/register.go +++ b/client/indexd/register.go @@ -7,7 +7,7 @@ import ( "strings" "github.com/calypr/data-client/common" - "github.com/calypr/data-client/indexd/drs" + "github.com/calypr/data-client/drs" 
"github.com/calypr/data-client/upload" "github.com/calypr/git-drs/drsmap" ) diff --git a/client/interface.go b/client/interface.go index 7173a2d3..d28962c6 100644 --- a/client/interface.go +++ b/client/interface.go @@ -3,10 +3,11 @@ package client import ( "context" + drs "github.com/calypr/data-client/drs" dataClient "github.com/calypr/data-client/g3client" - drs "github.com/calypr/data-client/indexd/drs" - hash "github.com/calypr/data-client/indexd/hash" - "github.com/calypr/git-drs/s3_utils" + hash "github.com/calypr/data-client/hash" + "github.com/calypr/data-client/s3utils" + "github.com/calypr/git-drs/cloud" ) type DRSClient interface { @@ -59,7 +60,7 @@ type DRSClient interface { BuildDrsObj(fileName string, checksum string, size int64, drsId string) (*drs.DRSObject, error) // Add an S3 URL to an existing indexd record - AddURL(s3URL, sha256, awsAccessKey, awsSecretKey, regionFlag, endpointFlag string, opts ...s3_utils.AddURLOption) (s3_utils.S3Meta, error) + AddURL(s3URL, sha256, awsAccessKey, awsSecretKey, regionFlag, endpointFlag string, opts ...cloud.AddURLOption) (s3utils.S3Meta, error) GetBucketName() string diff --git a/client/tests/add-url-helper_test.go b/client/tests/add-url-helper_test.go index b339e63d..258b1950 100644 --- a/client/tests/add-url-helper_test.go +++ b/client/tests/add-url-helper_test.go @@ -6,9 +6,8 @@ import ( "testing" "github.com/bytedance/sonic/encoder" + "github.com/calypr/git-drs/cloud" "github.com/calypr/git-drs/drsmap" - "github.com/calypr/git-drs/s3_utils" - "github.com/calypr/git-drs/utils" ) // TestParseS3URL_Valid tests parsing valid S3 URLs @@ -52,7 +51,7 @@ func TestParseS3URL_Valid(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - bucket, key, err := utils.ParseS3URL(tt.s3URL) + bucket, key, err := cloud.ParseS3URL(tt.s3URL) if (err != nil) != tt.wantErr { t.Errorf("ParseS3URL() error = %v, wantErr %v", err, tt.wantErr) return @@ -106,7 +105,7 @@ func TestParseS3URL_Invalid(t 
*testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, _, err := utils.ParseS3URL(tt.s3URL) + _, _, err := cloud.ParseS3URL(tt.s3URL) if (err != nil) != tt.wantErr { t.Errorf("ParseS3URL() error = %v, wantErr %v", err, tt.wantErr) } @@ -149,8 +148,8 @@ func TestGetBucketDetails_Gen3Success(t *testing.T) { return } - response := s3_utils.S3BucketsResponse{ - S3Buckets: map[string]*s3_utils.S3Bucket{ + response := cloud.S3BucketsResponse{ + S3Buckets: map[string]*cloud.S3Bucket{ "test-bucket": { Region: "us-west-2", EndpointURL: "https://s3.aws.amazon.com", diff --git a/client/tests/add-url_test.go b/client/tests/add-url_test.go index 21b16295..a1979bf7 100644 --- a/client/tests/add-url_test.go +++ b/client/tests/add-url_test.go @@ -3,7 +3,7 @@ package indexd_tests import ( "testing" - "github.com/calypr/git-drs/s3_utils" + "github.com/calypr/git-drs/cloud" ) // TestValidateInputs_ValidInputs tests validation with valid S3 URL and SHA256 @@ -36,7 +36,7 @@ func TestValidateInputs_ValidInputs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := s3_utils.ValidateInputs(tt.s3URL, tt.sha256) + err := cloud.ValidateInputs(tt.s3URL, tt.sha256) if (err != nil) != tt.wantErr { t.Errorf("validateInputs() error = %v, wantErr %v", err, tt.wantErr) } @@ -82,7 +82,7 @@ func TestValidateInputs_InvalidS3URL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := s3_utils.ValidateInputs(tt.s3URL, validSHA256) + err := cloud.ValidateInputs(tt.s3URL, validSHA256) if (err != nil) != tt.wantErr { t.Errorf("validateInputs() error = %v, wantErr %v", err, tt.wantErr) } @@ -133,7 +133,7 @@ func TestValidateInputs_InvalidSHA256(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := s3_utils.ValidateInputs(validS3URL, tt.sha256) + err := cloud.ValidateInputs(validS3URL, tt.sha256) if (err != nil) != tt.wantErr { t.Errorf("validateInputs() error = %v, wantErr %v", err, 
tt.wantErr) } @@ -147,7 +147,7 @@ func TestValidateInputs_SHA256Normalization(t *testing.T) { uppercaseSHA256 := "E3B0C44298FC1C149AFBF4C8996FB92427AE41E4649B934CA495991B7852B855" // Should not error on uppercase SHA256 (it gets normalized internally) - err := s3_utils.ValidateInputs(validS3URL, uppercaseSHA256) + err := cloud.ValidateInputs(validS3URL, uppercaseSHA256) if err != nil { t.Errorf("validateInputs() should accept uppercase SHA256, got error: %v", err) } @@ -159,7 +159,7 @@ func TestValidateInputs_HexDecodeValidation(t *testing.T) { // Test valid 64-character hex string validHex := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" - err := s3_utils.ValidateInputs(validS3URL, validHex) + err := cloud.ValidateInputs(validS3URL, validHex) if err != nil { t.Errorf("validateInputs() error = %v, want nil", err) } @@ -167,7 +167,7 @@ func TestValidateInputs_HexDecodeValidation(t *testing.T) { // Test that hex.DecodeString is properly checked // This has correct length but invalid hex invalidHex := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" - err = s3_utils.ValidateInputs(validS3URL, invalidHex) + err = cloud.ValidateInputs(validS3URL, invalidHex) if err == nil { t.Errorf("validateInputs() should reject invalid hex, got nil error") } @@ -201,7 +201,7 @@ func TestValidateInputs_CaseSensitivity(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := s3_utils.ValidateInputs(tt.s3URL, validSHA256) + err := cloud.ValidateInputs(tt.s3URL, validSHA256) if (err != nil) != tt.wantErr { t.Errorf("validateInputs() error = %v, wantErr %v", err, tt.wantErr) } @@ -245,7 +245,7 @@ func TestValidateInputs_EdgeCases(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := s3_utils.ValidateInputs(tt.s3URL, tt.sha256) + err := cloud.ValidateInputs(tt.s3URL, tt.sha256) if (err != nil) != tt.wantErr { t.Errorf("validateInputs() error = %v, wantErr %v", err, tt.wantErr) } @@ 
-260,6 +260,6 @@ func BenchmarkValidateInputs(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _ = s3_utils.ValidateInputs(s3URL, sha256) + _ = cloud.ValidateInputs(s3URL, sha256) } } diff --git a/cloud/agent_fetch_reader.go b/cloud/agent_fetch_reader.go new file mode 100644 index 00000000..b6e094a4 --- /dev/null +++ b/cloud/agent_fetch_reader.go @@ -0,0 +1,192 @@ +package cloud + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "sync/atomic" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +// progressReader wraps an io.ReadCloser and periodically writes progress to stderr. +type progressReader struct { + rc io.ReadCloser + label string + start time.Time + total int64 // accessed atomically + quit chan struct{} + done chan struct{} + ticker time.Duration +} + +func newProgressReader(rc io.ReadCloser, label string) io.ReadCloser { + p := &progressReader{ + rc: rc, + label: label, + start: time.Now(), + quit: make(chan struct{}), + done: make(chan struct{}), + ticker: 500 * time.Millisecond, + } + + go func() { + t := time.NewTicker(p.ticker) + defer t.Stop() + var last int64 + for { + select { + case <-t.C: + total := atomic.LoadInt64(&p.total) + elapsed := time.Since(p.start).Seconds() + var rate float64 + if elapsed > 0 { + rate = float64(total) / elapsed + } + // \r to overwrite the same line like git pull; no newline until done + fmt.Fprintf(os.Stderr, "\r%s: %d bytes (%.1f KiB/s)", p.label, total, rate/1024) + last = total + case <-p.quit: + // final line (replace same line, then newline) + total := atomic.LoadInt64(&p.total) + _ = last // in case we want to use last for something later + fmt.Fprintf(os.Stderr, "\r%s: %d bytes\n", p.label, total) + close(p.done) + return + } + } + }() + + return p +} + +func (p *progressReader) Read(b []byte) (int, error) { + n, err := p.rc.Read(b) + if n > 0 { + atomic.AddInt64(&p.total, int64(n)) + } + return n, err +} + +func (p 
*progressReader) Close() error { + // Close underlying reader first, then stop progress goroutine and wait for completion. + err := p.rc.Close() + close(p.quit) + <-p.done + return err +} + +// AgentFetchReader fetches the object described by `input` and returns an io.ReadCloser. +// It accepts `s3://bucket/key` URLs and converts them to HTTPS URLs. If `input.AWSEndpoint` +// is set it will use that endpoint in path-style (endpoint/bucket/key); otherwise it +// uses the default virtual-hosted AWS form: https://{bucket}.s3.amazonaws.com/{key}. +func AgentFetchReader(ctx context.Context, input S3ObjectParameters) (io.ReadCloser, error) { + if ctx == nil { + ctx = context.Background() + } + + raw := strings.TrimSpace(input.S3URL) + if raw == "" { + return nil, fmt.Errorf("AgentFetchReader: S3ObjectParameters.S3URL is empty") + } + + useSignedFetch := strings.TrimSpace(input.AWSAccessKey) != "" || + strings.TrimSpace(input.AWSSecretKey) != "" || + strings.TrimSpace(input.AWSRegion) != "" + if useSignedFetch { + if strings.TrimSpace(input.AWSAccessKey) == "" || strings.TrimSpace(input.AWSSecretKey) == "" || strings.TrimSpace(input.AWSRegion) == "" { + return nil, fmt.Errorf("AgentFetchReader: AWSAccessKey, AWSSecretKey, and AWSRegion are required for signed fetch") + } + + bucket, key, err := parseS3URL(raw) + if err != nil { + return nil, fmt.Errorf("AgentFetchReader: parse s3 url %q: %w", raw, err) + } + + s3Client, err := newS3Client(ctx, input) + if err != nil { + return nil, fmt.Errorf("AgentFetchReader: init s3 client: %w", err) + } + + out, err := s3Client.GetObject(ctx, &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + }) + if err != nil { + return nil, fmt.Errorf("AgentFetchReader: s3 GetObject failed (bucket=%q key=%q): %w", bucket, key, err) + } + if out.Body == nil { + return nil, fmt.Errorf("AgentFetchReader: response body is nil for s3://%s/%s", bucket, key) + } + + label := fmt.Sprintf("fetch s3://%s/%s", bucket, key) + return 
newProgressReader(out.Body, label), nil + } + + u, err := url.Parse(raw) + if err != nil { + return nil, fmt.Errorf("AgentFetchReader: parse url %q: %w", raw, err) + } + + var s3url string + switch u.Scheme { + case "s3": + bucket := u.Host + key := strings.TrimPrefix(u.Path, "/") + if bucket == "" || key == "" { + return nil, fmt.Errorf("AgentFetchReader: invalid s3 URL %q", raw) + } + if ep := strings.TrimSpace(input.AWSEndpoint); ep != "" { + // ensure endpoint has a scheme + if !strings.HasPrefix(ep, "http://") && !strings.HasPrefix(ep, "https://") { + ep = "https://" + ep + } + s3url = strings.TrimRight(ep, "/") + "/" + bucket + "/" + key + } else { + s3url = fmt.Sprintf("https://%s.s3.amazonaws.com/%s", bucket, key) + } + case "", "http", "https": + // allow bare host/path (no scheme) by assuming https, otherwise use as-is + if u.Scheme == "" { + s3url = "https://" + raw + } else { + s3url = raw + } + default: + return nil, fmt.Errorf("AgentFetchReader: unsupported URL scheme %q", u.Scheme) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, s3url, nil) + if err != nil { + return nil, fmt.Errorf("AgentFetchReader: create request: %w", err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("AgentFetchReader: http get %s: %w", s3url, err) + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + if _, copyErr := io.Copy(io.Discard, resp.Body); copyErr != nil { + _ = resp.Body.Close() + return nil, fmt.Errorf("AgentFetchReader: unexpected status %d fetching %s; failed to drain body: %w", resp.StatusCode, s3url, copyErr) + } + if closeErr := resp.Body.Close(); closeErr != nil { + return nil, fmt.Errorf("AgentFetchReader: unexpected status %d fetching %s; failed to close body: %w", resp.StatusCode, s3url, closeErr) + } + return nil, fmt.Errorf("AgentFetchReader: unexpected status %d fetching %s", resp.StatusCode, s3url) + } + + if resp.Body == nil { + return nil, fmt.Errorf("AgentFetchReader: response 
body is nil for %s", s3url) + } + // wrap response body with progress reporting that writes to stderr + label := fmt.Sprintf("fetch %s", s3url) + return newProgressReader(resp.Body, label), nil +} diff --git a/cloud/download.go b/cloud/download.go new file mode 100644 index 00000000..e3b057ef --- /dev/null +++ b/cloud/download.go @@ -0,0 +1,49 @@ +package cloud + +import ( + "fmt" + "io" + "net/http" + "os" + "path/filepath" +) + +// downloads a file to a specified path using a signed URL +func DownloadSignedUrl(signedURL string, dstPath string) error { + // Download the file using the signed URL + fileResponse, err := http.Get(signedURL) + if err != nil { + return err + } + defer fileResponse.Body.Close() + + // Check if the response status is OK + if fileResponse.StatusCode != http.StatusOK { + body, err := io.ReadAll(fileResponse.Body) + if err != nil { + return fmt.Errorf("failed to download file using signed URL: %s", fileResponse.Status) + } + return fmt.Errorf("failed to download file using signed URL: %s. 
Full error: %s", fileResponse.Status, string(body)) + } + + // Create the destination directory if it doesn't exist + err = os.MkdirAll(filepath.Dir(dstPath), os.ModePerm) + if err != nil { + return err + } + + // Create the destination file + dstFile, err := os.Create(dstPath) + if err != nil { + return err + } + defer dstFile.Close() + + // Write the file content to the destination file + _, err = io.Copy(dstFile, fileResponse.Body) + if err != nil { + return err + } + + return nil +} diff --git a/cloud/download_test.go b/cloud/download_test.go new file mode 100644 index 00000000..bb13bc45 --- /dev/null +++ b/cloud/download_test.go @@ -0,0 +1,42 @@ +package cloud + +import ( + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" +) + +func TestDownloadSignedUrl(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("content")) + })) + defer server.Close() + + dst := filepath.Join(t.TempDir(), "file.txt") + if err := DownloadSignedUrl(server.URL, dst); err != nil { + t.Fatalf("DownloadSignedUrl error: %v", err) + } + data, err := os.ReadFile(dst) + if err != nil { + t.Fatalf("read file: %v", err) + } + if string(data) != "content" { + t.Fatalf("unexpected file content: %s", string(data)) + } +} + +func TestDownloadSignedUrl_Error(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("bad")) + })) + defer server.Close() + + dst := filepath.Join(t.TempDir(), "file.txt") + if err := DownloadSignedUrl(server.URL, dst); err == nil { + t.Fatalf("expected error") + } +} diff --git a/cloud/downloader.go b/cloud/downloader.go new file mode 100644 index 00000000..19c0703f --- /dev/null +++ b/cloud/downloader.go @@ -0,0 +1,93 @@ +package cloud + +import ( + "context" + "crypto/sha256" + "fmt" + "io" + "os" + "path/filepath" +) + 
+// Download downloads the S3 object to a temporary file while computing its SHA256 hash. +// returns the computed SHA256 hash, temporary path and any error encountered. +func Download(ctx context.Context, info *S3Object, s3Input S3ObjectParameters, lfsRoot string) (string, string, error) { + // 2) object destination + etag := info.ETag + subdir1, subdir2 := "xx", "yy" + if len(etag) >= 4 { + subdir1 = etag[0:2] + subdir2 = etag[2:4] + } + objName := etag + if objName == "" { + objName = "unknown-etag" + } + tmpDir := filepath.Join(lfsRoot, "tmp-objects", subdir1, subdir2) + tmpObj := filepath.Join(tmpDir, objName) + + // 3) fetch bytes -> tmp, compute sha+count + + // Create the temporary directory and file where the S3 object will be streamed while computing its hash and size. + if err := os.MkdirAll(tmpDir, 0755); err != nil { + return "", "", fmt.Errorf("mkdir %s: %w", tmpDir, err) + } + + f, err := os.Create(tmpObj) + if err != nil { + return "", "", fmt.Errorf("create %s: %w", tmpObj, err) + } + // ensure any leftover file is closed and error propagated via named return + defer func() { + if f != nil { + if cerr := f.Close(); cerr != nil && err == nil { + err = fmt.Errorf("close tmp file: %w", cerr) + } + } + }() + + h := sha256.New() + + var reader io.ReadCloser + reader, err = AgentFetchReader(ctx, s3Input) + if err != nil { + return "", "", fmt.Errorf("fetch reader: %w", err) + } + // ensure close on any early return; propagate close error via named return + defer func() { + if reader != nil { + if cerr := reader.Close(); cerr != nil && err == nil { + err = fmt.Errorf("close reader: %w", cerr) + } + } + }() + + n, err := io.Copy(io.MultiWriter(f, h), reader) + if err != nil { + return "", "", fmt.Errorf("copy bytes to %s: %w", tmpObj, err) + } + + // explicitly close reader and handle error + if cerr := reader.Close(); cerr != nil { + return "", "", fmt.Errorf("close reader: %w", cerr) + } + reader = nil + + // ensure data is flushed to disk + if err = 
f.Sync(); err != nil { + return "", "", fmt.Errorf("sync %s: %w", tmpObj, err) + } + + // explicitly close tmp file before rename + if cerr := f.Close(); cerr != nil { + return "", "", fmt.Errorf("close %s: %w", tmpObj, cerr) + } + f = nil + + // use n (bytes written) to avoid unused var warnings + _ = n + + // compute hex SHA256 of the fetched content + computedSHA := fmt.Sprintf("%x", h.Sum(nil)) + return computedSHA, tmpObj, nil +} diff --git a/cloud/inspect.go b/cloud/inspect.go new file mode 100644 index 00000000..07351088 --- /dev/null +++ b/cloud/inspect.go @@ -0,0 +1,261 @@ +// Package cloud provides a small helper for Git-LFS + S3 object introspection. +// +// It: +// 1. determines the effective Git LFS storage root (.git/lfs vs git config lfs.storage) +// 2. derives a working-tree filename from the S3 object key (basename of key) +// 3. performs an S3 HEAD Object to retrieve size and user metadata (sha256 if present) +package cloud + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net/http" + "net/url" + "path" + "regexp" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +// S3ObjectParameters is a container for S3 object identification and access. +type S3ObjectParameters struct { + S3URL string + AWSAccessKey string + AWSSecretKey string + AWSRegion string + AWSEndpoint string // optional: custom endpoint (Ceph/MinIO/etc.) + SHA256 string // optional expected hex (64 chars); may carry a "sha256:" prefix, or be empty + DestinationPath string // optional override URL path (worktree filename) +} + +// S3Object is the result returned by InspectS3ForLFS. 
+type S3Object struct { + + // Object identity + Bucket string + Key string + Path string // basename of Key (filename), or override from input + + // HEAD-derived info + SizeBytes int64 + MetaSHA256 string // from user-defined object metadata (if present) + ETag string + LastModTime time.Time +} + +// InspectS3ForLFS does all 3 requested tasks. +func InspectS3ForLFS(ctx context.Context, in S3ObjectParameters) (*S3Object, error) { + if strings.TrimSpace(in.S3URL) == "" { + return nil, errors.New("S3URL is required") + } + if strings.TrimSpace(in.AWSRegion) == "" { + return nil, errors.New("AWSRegion is required") + } + if in.AWSAccessKey == "" || in.AWSSecretKey == "" { + return nil, errors.New("AWSAccessKey and AWSSecretKey are required") + } + + // 2) Parse S3 URL + derive working tree filename. + bucket, key, err := parseS3URL(in.S3URL) + if err != nil { + return nil, err + } + worktreeName := strings.TrimSpace(in.DestinationPath) + if worktreeName == "" { + worktreeName = path.Base(key) + if worktreeName == "." || worktreeName == "/" || worktreeName == "" { + return nil, fmt.Errorf("could not derive worktree name from key %q", key) + } + } else if worktreeName == "." || worktreeName == "/" { + return nil, fmt.Errorf("invalid worktree name override %q", worktreeName) + } + + // 3) HEAD on S3 to determine size and meta.SHA256. + s3Client, err := newS3Client(ctx, in) + if err != nil { + return nil, err + } + head, err := s3Client.HeadObject(ctx, &s3.HeadObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + }) + if err != nil { + return nil, fmt.Errorf("s3 HeadObject failed (bucket=%q key=%q): %w", bucket, key, err) + } + + metaSHA := extractSHA256FromMetadata(head.Metadata) + + // Optional: validate provided SHA256 against metadata if both exist. 
+ expected := normalizeSHA256(in.SHA256) + if expected != "" && metaSHA != "" && !strings.EqualFold(expected, metaSHA) { + return nil, fmt.Errorf("sha256 mismatch: expected=%s head.meta=%s", expected, metaSHA) + } + + var lm time.Time + if head.LastModified != nil { + lm = *head.LastModified + } + + if head.ContentLength == nil { + return nil, fmt.Errorf("s3 HeadObject missing ContentLength (bucket=%q key=%q)", bucket, key) + } + sizeBytes := *head.ContentLength + + var etag string + if head.ETag != nil { + etag = strings.Trim(*head.ETag, `"`) + } + + out := &S3Object{ + Bucket: bucket, + Key: key, + Path: worktreeName, + SizeBytes: sizeBytes, + MetaSHA256: metaSHA, + ETag: etag, + LastModTime: lm, + } + return out, nil +} + +// +// --- S3 parsing + client --- +// + +var virtualHostedRE = regexp.MustCompile(`^(.+?)\.s3(?:[.-]|$)`) + +// parseS3URL parses s3://bucket/key, virtual-hosted HTTPS (bucket.s3.../key) +// and path-style HTTPS (s3.../bucket/key). Returns bucket and key. +func parseS3URL(raw string) (string, string, error) { + u, err := url.Parse(raw) + if err != nil { + return "", "", err + } + + switch u.Scheme { + case "s3": + bucket := u.Host + key := strings.TrimPrefix(u.Path, "/") + return bucket, key, nil + case "http", "https": + host := u.Hostname() + + // virtual-hosted: bucket.s3.amazonaws.com or bucket.s3-region.amazonaws.com + if m := virtualHostedRE.FindStringSubmatch(host); m != nil { + bucket := m[1] + key := strings.TrimPrefix(u.Path, "/") + return bucket, key, nil + } + + // path-style: s3.../bucket/key + path := strings.TrimPrefix(u.Path, "/") + if path == "" { + return "", "", fmt.Errorf("no bucket in URL: %s", raw) + } + parts := strings.SplitN(path, "/", 2) + bucket := parts[0] + key := "" + if len(parts) == 2 { + key = parts[1] + } + return bucket, key, nil + default: + return "", "", fmt.Errorf("unsupported scheme: %s", u.Scheme) + } +} + +func newS3Client(ctx context.Context, in S3ObjectParameters) (*s3.Client, error) { + creds := 
credentials.NewStaticCredentialsProvider(in.AWSAccessKey, in.AWSSecretKey, "") + + // Custom HTTP client is useful for S3-compatible endpoints. + httpClient := &http.Client{ + Timeout: 60 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + }, + }, + } + + cfg, err := config.LoadDefaultConfig(ctx, + config.WithRegion(in.AWSRegion), + config.WithCredentialsProvider(creds), + config.WithHTTPClient(httpClient), + ) + if err != nil { + return nil, fmt.Errorf("aws config init failed: %w", err) + } + + opts := []func(*s3.Options){} + if strings.TrimSpace(in.AWSEndpoint) != "" { + ep := strings.TrimRight(in.AWSEndpoint, "/") + opts = append(opts, func(o *s3.Options) { + o.UsePathStyle = true // usually required for Ceph/MinIO/custom endpoints + o.BaseEndpoint = aws.String(ep) + }) + } + + return s3.NewFromConfig(cfg, opts...), nil +} + +// +// --- SHA256 metadata extraction --- +// + +var sha256HexRe = regexp.MustCompile(`(?i)^[0-9a-f]{64}$`) + +func normalizeSHA256(s string) string { + s = strings.TrimSpace(s) + s = strings.TrimPrefix(strings.ToLower(s), "sha256:") + s = strings.TrimSpace(s) + if s == "" { + return "" + } + if !sha256HexRe.MatchString(s) { + // If caller provided something malformed, treat as empty. + // Change this to a hard error if you prefer. + return "" + } + return strings.ToLower(s) +} + +func extractSHA256FromMetadata(md map[string]string) string { + if md == nil { + return "" + } + + // AWS SDK v2 exposes user-defined metadata WITHOUT the "x-amz-meta-" prefix, + // and normalizes keys to lower-case. 
+ candidates := []string{ + "sha256", + "checksum-sha256", + "content-sha256", + "oid-sha256", + "git-lfs-sha256", + } + + for _, k := range candidates { + if v, ok := md[k]; ok { + n := normalizeSHA256(v) + if n != "" { + return n + } + } + } + + // Sometimes people stash "sha256:" + for _, v := range md { + if n := normalizeSHA256(v); n != "" { + return n + } + } + + return "" +} diff --git a/cloud/inspect_test.go b/cloud/inspect_test.go new file mode 100644 index 00000000..77a5d449 --- /dev/null +++ b/cloud/inspect_test.go @@ -0,0 +1,123 @@ +package cloud + +import ( + "os" + "os/exec" + "strings" + "testing" +) + +func TestParseS3URL_S3Scheme(t *testing.T) { + b, k, err := parseS3URL("s3://my-bucket/path/to/file.bam") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if b != "my-bucket" { + t.Fatalf("bucket mismatch: %q", b) + } + if k != "path/to/file.bam" { + t.Fatalf("key mismatch: %q", k) + } +} + +func TestParseS3URL_HTTPSPathStyle(t *testing.T) { + b, k, err := parseS3URL("https://s3.example.org/my-bucket/path/to/file.bam") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if b != "my-bucket" { + t.Fatalf("bucket mismatch: %q", b) + } + if k != "path/to/file.bam" { + t.Fatalf("key mismatch: %q", k) + } +} + +func TestParseS3URL_HTTPSVirtualHosted(t *testing.T) { + b, k, err := parseS3URL("https://my-bucket.s3.amazonaws.com/path/to/file.bam") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if b != "my-bucket" { + t.Fatalf("bucket mismatch: %q", b) + } + if k != "path/to/file.bam" { + t.Fatalf("key mismatch: %q", k) + } +} + +func TestNormalizeSHA256(t *testing.T) { + hex := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + + if got := normalizeSHA256(hex); got != hex { + t.Fatalf("expected %q, got %q", hex, got) + } + + if got := normalizeSHA256("sha256:" + strings.ToUpper(hex)); got != hex { + t.Fatalf("expected %q, got %q", hex, got) + } + + if got := 
normalizeSHA256("not-a-sha"); got != "" { + t.Fatalf("expected empty for invalid, got %q", got) + } +} + +func TestExtractSHA256FromMetadata_ByKey(t *testing.T) { + hex := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + + md := map[string]string{ + "sha256": hex, + } + got := extractSHA256FromMetadata(md) + if got != hex { + t.Fatalf("expected %q, got %q", hex, got) + } +} + +func TestExtractSHA256FromMetadata_ByAlternateKey(t *testing.T) { + hex := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + + md := map[string]string{ + "checksum-sha256": "sha256:" + hex, + } + got := extractSHA256FromMetadata(md) + if got != hex { + t.Fatalf("expected %q, got %q", hex, got) + } +} + +func TestExtractSHA256FromMetadata_SearchValues(t *testing.T) { + hex := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + + md := map[string]string{ + "something": "sha256:" + hex, + } + got := extractSHA256FromMetadata(md) + if got != hex { + t.Fatalf("expected %q, got %q", hex, got) + } +} + +// --- test helpers --- + +func mustRun(t *testing.T, dir string, name string, args ...string) { + t.Helper() + cmd := exec.Command(name, args...) 
// ParseS3URL splits an "s3://bucket/key" URL into its bucket and key
// components. It returns an error when the "s3://" scheme prefix is missing,
// when the bucket name is empty, or when no non-empty key follows the bucket
// (e.g. "s3://bucket" or "s3://bucket/").
func ParseS3URL(s3url string) (string, string, error) {
	const s3Prefix = "s3://"
	rest, found := strings.CutPrefix(s3url, s3Prefix)
	if !found {
		return "", "", fmt.Errorf("S3 URL requires prefix 's3://': %s", s3url)
	}
	bucket, key, ok := strings.Cut(rest, "/")
	if !ok || key == "" {
		return "", "", fmt.Errorf("invalid S3 file URL: %s", s3url)
	}
	if bucket == "" {
		// "s3:///key" previously yielded an empty bucket with no error;
		// reject it explicitly.
		return "", "", fmt.Errorf("invalid S3 file URL: %s", s3url)
	}
	return bucket, key, nil
}
// CanDownloadFile checks if a file can be downloaded from the given signed
// URL by issuing a ranged GET for a single byte to mimic HEAD behavior
// (signed URLs frequently authorize only GET, not HEAD).
//
// It returns nil when the server answers 206 Partial Content (range honored)
// or 200 OK (range ignored but the object is readable), and a descriptive
// error otherwise.
func CanDownloadFile(signedURL string) error {
	req, err := http.NewRequest(http.MethodGet, signedURL, nil)
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Range", "bytes=0-0")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		// Wrap with %w (not %v) so callers can use errors.Is/As, matching
		// the error-wrapping convention used elsewhere in this package.
		return fmt.Errorf("error while sending the request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusPartialContent || resp.StatusCode == http.StatusOK {
		return nil
	}

	return fmt.Errorf("failed to access file, HTTP status: %d", resp.StatusCode)
}
// updatePrecommitCache updates the project's pre-commit cache with a mapping
// from a repository-relative `pathArg` to the given LFS `oid` and records the
// external source URL. It will:
//   - require a non-nil `logger`
//   - open the pre-commit cache (`precommit_cache.Open`)
//   - ensure cache directories exist
//   - convert the supplied worktree path to a repository-relative path
//   - create or update the per-path JSON entry with the current OID and timestamp
//   - create or update the per-OID JSON entry listing paths that reference it,
//     the external URL, and a content-change flag when the path's OID changed
//   - remove the path from the previous OID entry when the content changed
//
// Parameters:
//   - ctx: context for operations that may be cancellable
//   - logger: a non-nil `*slog.Logger` used for warnings; if nil the function
//     returns an error
//   - pathArg: the worktree path to record (absolute or relative); must not be empty
//   - oid: the LFS object id (string) to associate with the path
//   - externalURL: optional external source URL for the object; empty string is allowed
//
// Returns an error if any cache operation, path resolution, or I/O fails.
func updatePrecommitCache(ctx context.Context, logger *slog.Logger, pathArg, oid, externalURL string) error {
	if logger == nil {
		return errors.New("logger is required")
	}
	// Open pre-commit cache. Returns a configured Cache or error.
	cache, err := precommit_cache.Open(ctx)
	if err != nil {
		return err
	}

	// Ensure cache directories exist.
	if err := ensureCacheDirs(cache, logger); err != nil {
		return err
	}

	// Convert worktree path to repository-relative path.
	relPath, err := repoRelativePath(pathArg)
	if err != nil {
		return err
	}

	// Current timestamp in RFC3339 format (UTC).
	now := time.Now().UTC().Format(time.RFC3339)

	// Read previous path entry, if any, so we can detect content changes
	// (same path, different OID).
	pathFile := cachePathEntryFile(cache, relPath)
	prevEntry, prevExists, err := readPathEntry(pathFile)
	if err != nil {
		return err
	}
	// track whether content changed for this path
	contentChanged := prevExists && prevEntry.LFSOID != "" && prevEntry.LFSOID != oid

	// The new path entry is written BEFORE the OID entries are touched, so a
	// failure below leaves the path entry pointing at the new OID.
	if err := writeJSONAtomic(pathFile, precommit_cache.PathEntry{
		Path:      relPath,
		LFSOID:    oid,
		UpdatedAt: now,
	}); err != nil {
		return err
	}

	if err := upsertOIDEntry(cache, oid, relPath, externalURL, now, contentChanged); err != nil {
		return err
	}

	if contentChanged {
		// NOTE(review): the error from pruning the stale OID entry is
		// deliberately ignored — the update has already succeeded and this is
		// best-effort cleanup. Confirm that downstream readers tolerate a
		// stale path lingering in an old OID entry.
		_ = removePathFromOID(cache, prevEntry.LFSOID, relPath, now)
	}

	return nil
}

// ensureCacheDirs verifies and creates the pre-commit cache directory layout
// (paths and oids directories). It logs a warning when creating a missing
// cache root; any Stat error other than "not exist" is returned as-is.
func ensureCacheDirs(cache *precommit_cache.Cache, logger *slog.Logger) error {
	if cache == nil {
		return errors.New("cache is nil")
	}
	if _, err := os.Stat(cache.Root); err != nil {
		if os.IsNotExist(err) {
			logger.Warn("pre-commit cache directory missing; creating", "path", cache.Root)
		} else {
			return err
		}
	}
	// MkdirAll is a no-op when the directories already exist.
	if err := os.MkdirAll(cache.PathsDir, 0o755); err != nil {
		return fmt.Errorf("create cache paths dir: %w", err)
	}
	if err := os.MkdirAll(cache.OIDsDir, 0o755); err != nil {
		return fmt.Errorf("create cache oids dir: %w", err)
	}
	return nil
}
+func repoRelativePath(pathArg string) (string, error) { + if pathArg == "" { + return "", errors.New("empty worktree path") + } + root, err := gitrepo.GitTopLevel() + if err != nil { + return "", err + } + root, err = filepath.EvalSymlinks(root) + if err != nil { + return "", err + } + clean := filepath.Clean(pathArg) + if filepath.IsAbs(clean) { + clean, err = filepath.EvalSymlinks(clean) + if err != nil { + return "", err + } + rel, err := filepath.Rel(root, clean) + if err != nil { + return "", err + } + if strings.HasPrefix(rel, "..") { + return "", fmt.Errorf("path %s is outside repo root %s", clean, root) + } + return filepath.ToSlash(rel), nil + } + return filepath.ToSlash(clean), nil +} + +// cachePathEntryFile returns the filesystem path to the JSON path-entry file +// for the given repository-relative path within the provided Cache. +func cachePathEntryFile(cache *precommit_cache.Cache, path string) string { + return filepath.Join(cache.PathsDir, precommit_cache.EncodePath(path)+".json") +} + +// cacheOIDEntryFile returns the filesystem path to the JSON OID-entry file +// for the given LFS OID. The file is named by sha256(oid) to avoid filesystem +// restrictions and collisions. +func cacheOIDEntryFile(cache *precommit_cache.Cache, oid string) string { + sum := sha256.Sum256([]byte(oid)) + return filepath.Join(cache.OIDsDir, fmt.Sprintf("%x.json", sum[:])) +} + +// readPathEntry reads and parses a JSON PathEntry from disk. It returns the +// parsed entry, a boolean indicating existence, or an error on I/O/parse +// failure. +func readPathEntry(path string) (*precommit_cache.PathEntry, bool, error) { + data, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil, false, nil + } + return nil, false, err + } + var entry precommit_cache.PathEntry + if err := json.Unmarshal(data, &entry); err != nil { + return nil, false, err + } + return &entry, true, nil +} + +// readOIDEntry reads and parses a JSON OIDEntry from disk. 
If the file is +// missing it returns a freshly initialized entry (with LFSOID set to the +// supplied oid and UpdatedAt set to now). +func readOIDEntry(path string, oid string, now string) (*precommit_cache.OIDEntry, error) { + entry := &precommit_cache.OIDEntry{ + LFSOID: oid, + Paths: []string{}, + UpdatedAt: now, + } + data, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return entry, nil + } + return nil, err + } + if err := json.Unmarshal(data, entry); err != nil { + return nil, err + } + entry.LFSOID = oid + return entry, nil +} + +// upsertOIDEntry creates or updates the OID entry for `oid`, ensuring `path` +// is listed among its Paths, updating ExternalURL when provided, and setting +// content-change/state fields. The updated entry is written atomically. +func upsertOIDEntry(cache *precommit_cache.Cache, oid, path, externalURL, now string, contentChanged bool) error { + if cache == nil { + return errors.New("cache is nil") + } + oidFile := cacheOIDEntryFile(cache, oid) + entry, err := readOIDEntry(oidFile, oid, now) + if err != nil { + return err + } + + pathSet := make(map[string]struct{}, len(entry.Paths)+1) + for _, p := range entry.Paths { + pathSet[p] = struct{}{} + } + if path != "" { + pathSet[path] = struct{}{} + } + entry.Paths = sortedKeys(pathSet) + entry.UpdatedAt = now + entry.ContentChange = entry.ContentChange || contentChanged + if strings.TrimSpace(externalURL) != "" { + entry.ExternalURL = externalURL + } + + return writeJSONAtomic(oidFile, entry) +} + +// removePathFromOID removes `path` from the OID entry for `oid` and writes +// the updated entry atomically. Missing entries are treated as empty. 
+func removePathFromOID(cache *precommit_cache.Cache, oid, path, now string) error { + if cache == nil { + return errors.New("cache is nil") + } + oidFile := cacheOIDEntryFile(cache, oid) + entry, err := readOIDEntry(oidFile, oid, now) + if err != nil { + return err + } + pathSet := make(map[string]struct{}, len(entry.Paths)) + for _, p := range entry.Paths { + if p == path { + continue + } + pathSet[p] = struct{}{} + } + entry.Paths = sortedKeys(pathSet) + entry.UpdatedAt = now + + return writeJSONAtomic(oidFile, entry) +} + +// sortedKeys returns a sorted slice of keys from the provided string-set map. +func sortedKeys(set map[string]struct{}) []string { + keys := slices.Collect(maps.Keys(set)) + slices.Sort(keys) + return keys +} diff --git a/cmd/addurl/io.go b/cmd/addurl/io.go new file mode 100644 index 00000000..dda6e824 --- /dev/null +++ b/cmd/addurl/io.go @@ -0,0 +1,120 @@ +package addurl + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + + "github.com/calypr/git-drs/cloud" + "github.com/spf13/cobra" +) + +// writePointerFile writes a Git LFS pointer file at the given worktree path +// referencing the supplied oid and recording sizeBytes. It creates parent +// directories as needed and validates the path is non-empty. +func writePointerFile(pathArg, oid string, sizeBytes int64) error { + pointer := fmt.Sprintf( + "version https://git-lfs.github.com/spec/v1\noid sha256:%s\nsize %d\n", + oid, sizeBytes, + ) + if pathArg == "" { + return fmt.Errorf("empty worktree path") + } + safePath := filepath.Clean(pathArg) + dir := filepath.Dir(safePath) + if dir != "." 
&& dir != "/" { + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("mkdir %s: %w", dir, err) + } + } + if err := os.WriteFile(safePath, []byte(pointer), 0644); err != nil { + return fmt.Errorf("write %s: %w", safePath, err) + } + + if _, err := fmt.Fprintf(os.Stderr, "Added Git LFS pointer file at %s\n", safePath); err != nil { + return fmt.Errorf("stderr write: %w", err) + } + return nil +} + +// maybeTrackLFS ensures the supplied path is tracked by Git LFS by invoking +// the provided gitLFSTrack callback when the path is not already tracked. +// It reports the addition to stderr for user guidance. +func maybeTrackLFS(ctx context.Context, gitLFSTrack func(context.Context, string) (bool, error), pathArg string, isTracked bool) error { + if isTracked { + return nil + } + if _, err := gitLFSTrack(ctx, pathArg); err != nil { + return fmt.Errorf("git lfs track %s: %w", pathArg, err) + } + + if _, err := fmt.Fprintf(os.Stderr, "Info: Added to Git LFS. Remember to `git add %s` and `git commit ...`", pathArg); err != nil { + return fmt.Errorf("stderr write: %w", err) + } + return nil +} + +// printResolvedInfo writes a human-readable summary of resolved Git/LFS and +// S3 object information to the command's stdout for user confirmation. 
+func printResolvedInfo(cmd *cobra.Command, gitCommonDir, lfsRoot string, s3Info *cloud.S3Object, pathArg string, isTracked bool, sha256 string) error { + if _, err := fmt.Fprintf(cmd.OutOrStdout(), ` +Resolved Git LFS s3Info +--------------------- +Git common dir : %s +LFS storage : %s + +S3 object +--------- +Bucket : %s +Key : %s +Worktree name : %s +Size (bytes) : %d +SHA256 (meta) : %s +ETag : %s +Last modified : %s + +Worktree +------------- +path : %s +tracked by LFS : %v +sha256 param : %s + +`, + gitCommonDir, + lfsRoot, + s3Info.Bucket, + s3Info.Key, + s3Info.Path, + s3Info.SizeBytes, + s3Info.MetaSHA256, + s3Info.ETag, + s3Info.LastModTime.Format("2006-01-02T15:04:05Z07:00"), + pathArg, + isTracked, + sha256, + ); err != nil { + return fmt.Errorf("print resolved s3Info: %w", err) + } + return nil +} + +// writeJSONAtomic marshals `value` to JSON and writes it to `path` atomically +// by writing to a temporary file in the same directory and renaming it. It +// ensures parent directories exist. 
+func writeJSONAtomic(path string, value any) error { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + tmp := path + ".tmp" + data, err := json.Marshal(value) + if err != nil { + return err + } + if err := os.WriteFile(tmp, data, 0o644); err != nil { + return err + } + return os.Rename(tmp, path) +} diff --git a/cmd/addurl/main.go b/cmd/addurl/main.go index 6d9ab4eb..5b503702 100644 --- a/cmd/addurl/main.go +++ b/cmd/addurl/main.go @@ -1,118 +1,158 @@ package addurl import ( + "context" "errors" "fmt" + "net/url" "os" - "path/filepath" + "strings" - "github.com/calypr/git-drs/config" - "github.com/calypr/git-drs/drslog" - "github.com/calypr/git-drs/gitrepo" - "github.com/calypr/git-drs/s3_utils" - "github.com/calypr/git-drs/utils" + "github.com/calypr/git-drs/cloud" "github.com/spf13/cobra" ) -// AddURLCmd represents the add-url command -var AddURLCmd = &cobra.Command{ - Use: "add-url ", - Short: "Add a file to the Git DRS repo using an S3 URL", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) != 2 { - cmd.SilenceUsage = false - return fmt.Errorf("error: requires exactly 2 arguments (S3 URL and SHA256), received %d\n\nUsage: %s\n\nSee 'git drs add-url --help' for more details", len(args), cmd.UseLine()) - } - return nil - }, - RunE: func(cmd *cobra.Command, args []string) error { - myLogger := drslog.GetLogger() - - // set git config lfs.allowincompletepush = true - // set git config lfs.allowincompletepush = true - if err := gitrepo.SetGitConfigOptions(map[string]string{"lfs.allowincompletepush": "true"}); err != nil { - return fmt.Errorf("unable to configure git to push pointers: %v", err) - } - - // Parse arguments - s3URL := args[0] - sha256 := args[1] - awsAccessKey, _ := cmd.Flags().GetString(s3_utils.AWS_KEY_FLAG_NAME) - awsSecretKey, _ := cmd.Flags().GetString(s3_utils.AWS_SECRET_FLAG_NAME) - awsRegion, _ := cmd.Flags().GetString(s3_utils.AWS_REGION_FLAG_NAME) - awsEndpoint, _ := 
cmd.Flags().GetString(s3_utils.AWS_ENDPOINT_URL_FLAG_NAME) - remote, _ := cmd.Flags().GetString("remote") - - // if providing credentials, access key and secret must both be provided - if (awsAccessKey == "" && awsSecretKey != "") || (awsAccessKey != "" && awsSecretKey == "") { - return errors.New("incomplete credentials provided as environment variables. Please run `export " + s3_utils.AWS_KEY_ENV_VAR + "=` and `export " + s3_utils.AWS_SECRET_ENV_VAR + "=` to configure both") - } - - // if none provided, use default AWS configuration on file - if awsAccessKey == "" && awsSecretKey == "" { - myLogger.Debug("No AWS credentials provided. Using default AWS configuration from file.") - } - - cfg, err := config.LoadConfig() - if err != nil { - return fmt.Errorf("error loading config: %v", err) - } - - remoteName, err := cfg.GetRemoteOrDefault(remote) - if err != nil { - return fmt.Errorf("error getting default remote: %v", err) - } - - drsClient, err := cfg.GetRemoteClient(remoteName, myLogger) - if err != nil { - return fmt.Errorf("error getting current remote client: %v", err) - } - - // Call client.AddURL to handle Gen3 interactions - meta, err := drsClient.AddURL(s3URL, sha256, awsAccessKey, awsSecretKey, awsRegion, awsEndpoint) - if err != nil { - return err - } - - // Generate and add pointer file - _, relFilePath, err := utils.ParseS3URL(s3URL) - if err != nil { - return fmt.Errorf("failed to parse S3 URL: %w", err) - } - if err := generatePointerFile(relFilePath, sha256, meta.Size); err != nil { - return fmt.Errorf("failed to generate pointer file: %w", err) - } - myLogger.Debug("S3 URL successfully added to Git DRS repo.") - return nil - }, +var Cmd = NewCommand() + +// NewCommand constructs the Cobra command for the `add-url` subcommand, +// wiring usage, argument validation and the RunE handler. 
// NewCommand constructs the Cobra command for the `add-url` subcommand,
// wiring usage, argument validation (one or two positional args) and the
// RunE handler.
func NewCommand() *cobra.Command {
	cmd := &cobra.Command{
		Use:   "add-url [path]",
		Short: "Add a file to the Git DRS repo using an S3 URL",
		Args: func(cmd *cobra.Command, args []string) error {
			if len(args) < 1 || len(args) > 2 {
				return errors.New("usage: add-url [path]")
			}
			return nil
		},
		RunE: runAddURL,
	}
	addFlags(cmd)
	return cmd
}

// addFlags registers command-line flags for AWS credentials, region and
// endpoint (each defaulting to its environment variable) plus an optional
// `sha256` expected checksum.
func addFlags(cmd *cobra.Command) {
	cmd.Flags().String(
		cloud.AWS_KEY_FLAG_NAME,
		os.Getenv(cloud.AWS_KEY_ENV_VAR),
		"AWS access key",
	)

	cmd.Flags().String(
		cloud.AWS_SECRET_FLAG_NAME,
		os.Getenv(cloud.AWS_SECRET_ENV_VAR),
		"AWS secret key",
	)

	cmd.Flags().String(
		cloud.AWS_REGION_FLAG_NAME,
		os.Getenv(cloud.AWS_REGION_ENV_VAR),
		"AWS S3 region",
	)

	cmd.Flags().String(
		cloud.AWS_ENDPOINT_URL_FLAG_NAME,
		os.Getenv(cloud.AWS_ENDPOINT_URL_ENV_VAR),
		"AWS S3 endpoint (optional, for Ceph/MinIO)",
	)

	// New flag: optional expected SHA256
	cmd.Flags().String(
		"sha256",
		"",
		"Expected SHA256 checksum (optional)",
	)
}

// runAddURL is the Cobra RunE wrapper that delegates execution to the
// AddURL service (NewAddURLService).
func runAddURL(cmd *cobra.Command, args []string) (err error) {
	return NewAddURLService().Run(cmd, args)
}

// download wraps cloud.Download to fetch the S3 object, returning the
// computed SHA256 and the path to the temporary downloaded file.
// The caller is responsible for moving/deleting the temporary file.
// We include this wrapper function (as a var) to allow stubbing in tests.
var download = func(ctx context.Context, info *cloud.S3Object, input cloud.S3ObjectParameters, lfsRoot string) (string, string, error) {
	return cloud.Download(ctx, info, input, lfsRoot)
}

// addURLInput bundles the parsed CLI inputs for add-url: the raw S3 URL, the
// resolved destination worktree path, the optional expected SHA256, and the
// assembled cloud.S3ObjectParameters.
type addURLInput struct {
	s3URL    string
	path     string
	sha256   string
	s3Params cloud.S3ObjectParameters
}

// parseAddURLInput parses CLI args and flags into an addURLInput, validates
// required AWS credentials and region, and constructs cloud.S3ObjectParameters.
func parseAddURLInput(cmd *cobra.Command, args []string) (addURLInput, error) {
	s3URL := args[0]

	pathArg, err := resolvePathArg(s3URL, args)
	if err != nil {
		return addURLInput{}, err
	}

	sha256Param, err := cmd.Flags().GetString("sha256")
	if err != nil {
		return addURLInput{}, fmt.Errorf("read flag sha256: %w", err)
	}

	awsKey, err := cmd.Flags().GetString(cloud.AWS_KEY_FLAG_NAME)
	if err != nil {
		return addURLInput{}, fmt.Errorf("read flag %s: %w", cloud.AWS_KEY_FLAG_NAME, err)
	}
	awsSecret, err := cmd.Flags().GetString(cloud.AWS_SECRET_FLAG_NAME)
	if err != nil {
		return addURLInput{}, fmt.Errorf("read flag %s: %w", cloud.AWS_SECRET_FLAG_NAME, err)
	}
	awsRegion, err := cmd.Flags().GetString(cloud.AWS_REGION_FLAG_NAME)
	if err != nil {
		return addURLInput{}, fmt.Errorf("read flag %s: %w", cloud.AWS_REGION_FLAG_NAME, err)
	}
	awsEndpoint, err := cmd.Flags().GetString(cloud.AWS_ENDPOINT_URL_FLAG_NAME)
	if err != nil {
		return addURLInput{}, fmt.Errorf("read flag %s: %w", cloud.AWS_ENDPOINT_URL_FLAG_NAME, err)
	}

	// Unlike the previous implementation, both credentials AND region are
	// hard requirements here (no fallback to the default AWS config chain).
	if awsKey == "" || awsSecret == "" {
		return addURLInput{}, errors.New("AWS credentials must be provided via flags or environment variables")
	}
	if awsRegion == "" {
		return addURLInput{}, errors.New("AWS region must be provided via flag or environment variable")
	}

	s3Input := cloud.S3ObjectParameters{
		S3URL:           s3URL,
		AWSAccessKey:    awsKey,
		AWSSecretKey:    awsSecret,
		AWSRegion:       awsRegion,
		AWSEndpoint:     awsEndpoint,
		SHA256:          sha256Param,
		DestinationPath: pathArg,
	}

	return addURLInput{
		s3URL:    s3URL,
		path:     pathArg,
		sha256:   sha256Param,
		s3Params: s3Input,
	}, nil
}
// resolvePathArg returns the destination worktree path for the object: the
// explicit second CLI argument when provided, otherwise the path component
// of the S3 URL with its leading slash removed.
//
// It returns an error when the URL cannot be parsed or when no usable path
// can be derived (e.g. "s3://bucket" with no key), so the command fails fast
// instead of later attempting to write a pointer file to an empty path.
func resolvePathArg(s3URL string, args []string) (string, error) {
	if len(args) == 2 {
		return args[1], nil
	}
	u, err := url.Parse(s3URL)
	if err != nil {
		return "", err
	}
	derived := strings.TrimPrefix(u.Path, "/")
	if derived == "" {
		return "", fmt.Errorf("cannot derive a destination path from %q; pass an explicit [path] argument", s3URL)
	}
	return derived, nil
}
{"config", "lfs.customtransfer.drs.default-remote", "calypr-dev"}, + {"config", "lfs.customtransfer.drs.remote.calypr-dev.type", "gen3"}, + {"config", "lfs.customtransfer.drs.remote.calypr-dev.project", "calypr-dev"}, + {"config", "lfs.customtransfer.drs.remote.calypr-dev.endpoint", "https://calypr-dev.ohsu.edu"}, + {"config", "lfs.customtransfer.drs.remote.calypr-dev.bucket", "cbds"}, + } + for _, args := range cmds { + cmd := exec.Command("git", args...) + cmd.Dir = tempDir + if err := cmd.Run(); err != nil { + t.Fatalf("git %v failed: %v", args, err) + } + } + loaded, err := config.LoadConfig() if err != nil { - t.Fatalf("generatePointerFile error: %v", err) + t.Fatalf("LoadConfig error: %v", err) + } + if loaded.DefaultRemote != "calypr-dev" { + t.Fatalf("expected default remote set, got %s", loaded.DefaultRemote) + } + + service := NewAddURLService() + resetStubs := stubAddURLDeps(t, service, + func(ctx context.Context, in cloud.S3ObjectParameters) (*cloud.S3Object, error) { + return &cloud.S3Object{ + Bucket: "bucket", + Key: "path/to/file.bin", + Path: "file.bin", + SizeBytes: int64(len(content)), + MetaSHA256: "", + ETag: "abcd1234", + LastModTime: time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC), + }, nil + }, + func(path string) (bool, error) { + return true, nil + }, + // download stub: write the LFS object into lfsRoot and return the sha + func(ctx context.Context, info *cloud.S3Object, input cloud.S3ObjectParameters, lfsRootPath string) (string, string, error) { + objPath := filepath.Join(lfsRootPath, "objects", shaHex[0:2], shaHex[2:4], shaHex) + if err := os.MkdirAll(filepath.Dir(objPath), 0755); err != nil { + return "", "", err + } + if err := os.WriteFile(objPath, []byte(content), 0644); err != nil { + return "", "", err + } + return shaHex, objPath, nil + }, + ) + t.Cleanup(resetStubs) + + cmd := NewCommand() + var out bytes.Buffer + cmd.SetOut(&out) + requireFlags(t, cmd) + + oldwd := mustChdir(t, tempDir) + t.Cleanup(func() { _ = os.Chdir(oldwd) }) 
+ + if err := service.Run(cmd, []string{"s3://bucket/path/to/file.bin"}); err != nil { + t.Fatalf("service.Run error: %v", err) } - content, err := os.ReadFile(path) + pointerPath := filepath.Join(tempDir, "path/to/file.bin") + pointerBytes, err := os.ReadFile(pointerPath) if err != nil { t.Fatalf("read pointer file: %v", err) } + expectedPointer := fmt.Sprintf( + "version https://git-lfs.github.com/spec/v1\noid sha256:%s\nsize %d\n", + shaHex, + len(content), + ) + if string(pointerBytes) != expectedPointer { + t.Fatalf("pointer mismatch: expected %q, got %q", expectedPointer, string(pointerBytes)) + } + + lfsObject := filepath.Join(lfsRoot, "objects", shaHex[0:2], shaHex[2:4], shaHex) + if _, err := os.Stat(lfsObject); err != nil { + t.Fatalf("expected LFS object at %s: %v", lfsObject, err) + } + + drsObject, err := drsmap.DrsInfoFromOid(shaHex) + if err != nil { + t.Fatalf("read drs object: %v", err) + } + if len(drsObject.AccessMethods) == 0 { + t.Fatalf("expected access methods in drs object") + } + if got := drsObject.AccessMethods[0].AccessURL.URL; got != "s3://bucket/path/to/file.bin" { + t.Fatalf("unexpected access URL: %s", got) + } +} + +func TestUpdatePrecommitCacheWritesEntries(t *testing.T) { + repo := setupGitRepo(t) + path := filepath.Join(repo, "data", "file.bin") + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(path, []byte("data"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + oldwd := mustChdir(t, repo) + t.Cleanup(func() { _ = os.Chdir(oldwd) }) + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + oid := "sha256deadbeef" + externalURL := "s3://bucket/data/file.bin" + + if err := updatePrecommitCache(context.Background(), logger, path, oid, externalURL); err != nil { + t.Fatalf("updatePrecommitCache: %v", err) + } + + cacheRoot := filepath.Join(repo, ".git", "drs", "pre-commit", "v1") + pathsDir := filepath.Join(cacheRoot, "paths") + oidDir := 
filepath.Join(cacheRoot, "oids") + + pathEntryFile := filepath.Join(pathsDir, precommit_cache.EncodePath("data/file.bin")+".json") + pathData, err := os.ReadFile(pathEntryFile) + if err != nil { + t.Fatalf("read path entry: %v", err) + } + var pathEntry precommit_cache.PathEntry + if err := json.Unmarshal(pathData, &pathEntry); err != nil { + t.Fatalf("unmarshal path entry: %v", err) + } + if pathEntry.Path != "data/file.bin" { + t.Fatalf("expected path entry path to be %q, got %q", "data/file.bin", pathEntry.Path) + } + if pathEntry.LFSOID != oid { + t.Fatalf("expected path entry oid to be %q, got %q", oid, pathEntry.LFSOID) + } + if pathEntry.UpdatedAt == "" { + t.Fatalf("expected updated_at to be set") + } + + oidSum := sha256.Sum256([]byte(oid)) + oidEntryFile := filepath.Join(oidDir, fmt.Sprintf("%x.json", oidSum[:])) + oidData, err := os.ReadFile(oidEntryFile) + if err != nil { + t.Fatalf("read oid entry: %v", err) + } + var oidEntry precommit_cache.OIDEntry + if err := json.Unmarshal(oidData, &oidEntry); err != nil { + t.Fatalf("unmarshal oid entry: %v", err) + } + if oidEntry.LFSOID != oid { + t.Fatalf("expected oid entry oid to be %q, got %q", oid, oidEntry.LFSOID) + } + if oidEntry.ExternalURL != externalURL { + t.Fatalf("expected oid entry external_url to be %q, got %q", externalURL, oidEntry.ExternalURL) + } + if len(oidEntry.Paths) != 1 || oidEntry.Paths[0] != "data/file.bin" { + t.Fatalf("expected oid entry paths to include data/file.bin, got %v", oidEntry.Paths) + } + if oidEntry.UpdatedAt == "" { + t.Fatalf("expected oid entry updated_at to be set") + } +} + +func TestUpdatePrecommitCacheContentChanged(t *testing.T) { + repo := setupGitRepo(t) + path := filepath.Join(repo, "data", "file.bin") + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(path, []byte("data"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + oldwd := mustChdir(t, repo) + t.Cleanup(func() { _ = 
os.Chdir(oldwd) }) + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + firstOID := "sha256first" + secondOID := "sha256second" + + if err := updatePrecommitCache(context.Background(), logger, path, firstOID, "s3://bucket/first"); err != nil { + t.Fatalf("updatePrecommitCache first: %v", err) + } + if err := updatePrecommitCache(context.Background(), logger, path, secondOID, "s3://bucket/second"); err != nil { + t.Fatalf("updatePrecommitCache second: %v", err) + } - if len(content) == 0 { - t.Fatalf("expected pointer file content") + cacheRoot := filepath.Join(repo, ".git", "drs", "pre-commit", "v1") + oidDir := filepath.Join(cacheRoot, "oids") + + firstSum := sha256.Sum256([]byte(firstOID)) + firstEntryFile := filepath.Join(oidDir, fmt.Sprintf("%x.json", firstSum[:])) + firstData, err := os.ReadFile(firstEntryFile) + if err != nil { + t.Fatalf("read first oid entry: %v", err) + } + var firstEntry precommit_cache.OIDEntry + if err := json.Unmarshal(firstData, &firstEntry); err != nil { + t.Fatalf("unmarshal first oid entry: %v", err) + } + if len(firstEntry.Paths) != 0 { + t.Fatalf("expected old oid entry paths to be empty, got %v", firstEntry.Paths) + } + + secondSum := sha256.Sum256([]byte(secondOID)) + secondEntryFile := filepath.Join(oidDir, fmt.Sprintf("%x.json", secondSum[:])) + secondData, err := os.ReadFile(secondEntryFile) + if err != nil { + t.Fatalf("read second oid entry: %v", err) + } + var secondEntry precommit_cache.OIDEntry + if err := json.Unmarshal(secondData, &secondEntry); err != nil { + t.Fatalf("unmarshal second oid entry: %v", err) + } + if !secondEntry.ContentChange { + t.Fatalf("expected content_changed to be true") + } + if len(secondEntry.Paths) != 1 || secondEntry.Paths[0] != "data/file.bin" { + t.Fatalf("expected new oid entry paths to include data/file.bin, got %v", secondEntry.Paths) + } +} + +// deprecated test case: now that we always "trust" the client-provided SHA256, this case is not applicable +//func 
TestRunAddURL_SHA256Mismatch(t *testing.T) { +// ... +//} + +func stubAddURLDeps( + t *testing.T, + service *AddURLService, + inspectFn func(context.Context, cloud.S3ObjectParameters) (*cloud.S3Object, error), + isTrackedFn func(string) (bool, error), + downloadFn func(context.Context, *cloud.S3Object, cloud.S3ObjectParameters, string) (string, string, error), +) func() { + t.Helper() + origInspect := service.inspectS3 + origIsTracked := service.isLFSTracked + origDownload := service.download + + service.inspectS3 = inspectFn + service.isLFSTracked = isTrackedFn + service.download = downloadFn + + return func() { + service.inspectS3 = origInspect + service.isLFSTracked = origIsTracked + service.download = origDownload + } +} + +func requireFlags(t *testing.T, cmd *cobra.Command) { + t.Helper() + if err := cmd.Flags().Set(cloud.AWS_KEY_FLAG_NAME, "key"); err != nil { + t.Fatalf("set aws key: %v", err) + } + if err := cmd.Flags().Set(cloud.AWS_SECRET_FLAG_NAME, "secret"); err != nil { + t.Fatalf("set aws secret: %v", err) + } + if err := cmd.Flags().Set(cloud.AWS_REGION_FLAG_NAME, "region"); err != nil { + t.Fatalf("set aws region: %v", err) + } +} + +func mustChdir(t *testing.T, dir string) string { + t.Helper() + old, err := os.Getwd() + if err != nil { + // non-fatal since we just want to return something that won't cause Chdir to fail, but log it for visibility + // The issue is specific to GitHub Actions because cleanup order can vary between test runners. The temp directory cleanup happens before your t.Cleanup runs, making the current directory invalid. 
+ t.Logf("An error occurred trying to Getwd, continuing with '' old dir: %v", err) + old = "" + } + if err := os.Chdir(dir); err != nil { + t.Fatalf("A fatal error occurred tying Chdir(%s): %v", dir, err) + } + return old +} + +func setupGitRepo(t *testing.T) string { + t.Helper() + dir := t.TempDir() + gitCmd(t, dir, "init") + gitCmd(t, dir, "config", "user.email", "test@example.com") + gitCmd(t, dir, "config", "user.name", "Test User") + return dir +} + +func gitCmd(t *testing.T, dir string, args ...string) { + t.Helper() + cmd := exec.Command("git", args...) + cmd.Dir = dir + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("git %s failed: %v (%s)", strings.Join(args, " "), err, string(out)) } } diff --git a/cmd/addurl/service.go b/cmd/addurl/service.go new file mode 100644 index 00000000..ada4d31d --- /dev/null +++ b/cmd/addurl/service.go @@ -0,0 +1,161 @@ +package addurl + +import ( + "context" + "fmt" + "log/slog" + "os" + "path/filepath" + + "github.com/calypr/data-client/drs" + "github.com/calypr/git-drs/cloud" + "github.com/calypr/git-drs/config" + "github.com/calypr/git-drs/drslog" + "github.com/calypr/git-drs/drsmap" + "github.com/calypr/git-drs/lfs" + "github.com/spf13/cobra" +) + +// AddURLService groups injectable dependencies used to implement the add-url +// behavior (logger factory, S3 inspection, LFS helpers, config loader, etc.). 
+type AddURLService struct { + newLogger func(string, bool) (*slog.Logger, error) + inspectS3 func(ctx context.Context, input cloud.S3ObjectParameters) (*cloud.S3Object, error) + isLFSTracked func(path string) (bool, error) + getGitRoots func(ctx context.Context) (string, string, error) + gitLFSTrack func(ctx context.Context, path string) (bool, error) + loadConfig func() (*config.Config, error) + download func(ctx context.Context, info *cloud.S3Object, input cloud.S3ObjectParameters, lfsRoot string) (string, string, error) +} + +// NewAddURLService constructs an AddURLService populated with production +// implementations of its dependencies. +func NewAddURLService() *AddURLService { + return &AddURLService{ + newLogger: drslog.NewLogger, + inspectS3: cloud.InspectS3ForLFS, + isLFSTracked: lfs.IsLFSTracked, + getGitRoots: lfs.GetGitRootDirectories, + gitLFSTrack: lfs.GitLFSTrackReadOnly, + loadConfig: config.LoadConfig, + download: download, + } +} + +// Run executes the add-url workflow: parse CLI input, inspect the S3 object, +// ensure the LFS object exists in local storage, write a Git LFS pointer file, +// update the pre-commit cache (best-effort), optionally add a git-lfs track +// entry, and record the DRS mapping. 
+func (s *AddURLService) Run(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + if ctx == nil { + ctx = context.Background() + } + + logger, err := s.newLogger("", false) + if err != nil { + return fmt.Errorf("error creating logger: %v", err) + } + + input, err := parseAddURLInput(cmd, args) + if err != nil { + return err + } + + s3Info, err := s.inspectS3(ctx, input.s3Params) + if err != nil { + return err + } + + isTracked, err := s.isLFSTracked(input.path) + if err != nil { + return fmt.Errorf("check LFS tracking for %s: %w", input.path, err) + } + + gitCommonDir, lfsRoot, err := s.getGitRoots(ctx) + if err != nil { + return fmt.Errorf("get git root directories: %w", err) + } + + if err := printResolvedInfo(cmd, gitCommonDir, lfsRoot, s3Info, input.path, isTracked, input.sha256); err != nil { + return err + } + + oid, err := s.ensureLFSObject(ctx, s3Info, input, lfsRoot) + if err != nil { + return err + } + + if err := writePointerFile(input.path, oid, s3Info.SizeBytes); err != nil { + return err + } + + if err := updatePrecommitCache(ctx, logger, input.path, oid, input.s3URL); err != nil { + logger.Warn("pre-commit cache update skipped", "error", err) + } + + if err := maybeTrackLFS(ctx, s.gitLFSTrack, input.path, isTracked); err != nil { + return err + } + + cfg, err := s.loadConfig() + if err != nil { + return fmt.Errorf("error getting config: %v", err) + } + + remote, err := cfg.GetDefaultRemote() + if err != nil { + return err + } + + remoteConfig := cfg.GetRemote(remote) + if remoteConfig == nil { + return fmt.Errorf("error getting remote configuration for %s", remote) + } + + builder := drs.NewObjectBuilder(remoteConfig.GetBucketName(), remoteConfig.GetProjectId()) + + file := lfs.LfsFileInfo{ + Name: input.path, + Size: s3Info.SizeBytes, + Oid: oid, + } + if _, err := drsmap.WriteDrsFile(builder, file, &input.s3URL); err != nil { + return fmt.Errorf("error WriteDrsFile: %v", err) + } + + return nil +} + +// ensureLFSObject ensures the 
LFS object identified by s3Info exists in the +// repository's LFS storage. If the input includes an explicit SHA256 that is +// returned immediately; otherwise the object is downloaded into a temporary +// file and moved into the LFS `objects` storage, returning the object's oid. +func (s *AddURLService) ensureLFSObject(ctx context.Context, s3Info *cloud.S3Object, input addURLInput, lfsRoot string) (string, error) { + if input.sha256 != "" { + return input.sha256, nil + } + + computedSHA, tmpObj, err := s.download(ctx, s3Info, input.s3Params, lfsRoot) + if err != nil { + return "", err + } + + oid := computedSHA + dstDir := filepath.Join(lfsRoot, "objects", oid[0:2], oid[2:4]) + dstObj := filepath.Join(dstDir, oid) + + if err := os.MkdirAll(dstDir, 0755); err != nil { + return "", fmt.Errorf("mkdir %s: %w", dstDir, err) + } + + if err := os.Rename(tmpObj, dstObj); err != nil { + return "", fmt.Errorf("rename %s to %s: %w", tmpObj, dstObj, err) + } + + if _, err := fmt.Fprintf(os.Stderr, "Added data file at %s\n", dstObj); err != nil { + return "", fmt.Errorf("stderr write: %w", err) + } + + return computedSHA, nil +} diff --git a/cmd/cache/cache_test.go b/cmd/cache/cache_test.go deleted file mode 100644 index b33d9c05..00000000 --- a/cmd/cache/cache_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package cache - -import ( - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestCacheArgs(t *testing.T) { - // Test with 1 argument (valid) - err := Cmd.Args(Cmd, []string{"manifest.tsv"}) - assert.NoError(t, err) - - // Test with no arguments (invalid) - err = Cmd.Args(Cmd, []string{}) - assert.Error(t, err) - - // Test with multiple arguments (invalid) - err = Cmd.Args(Cmd, []string{"m1.tsv", "m2.tsv"}) - assert.Error(t, err) -} - -func TestCacheRun(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "cache-test-*") - assert.NoError(t, err) - defer os.RemoveAll(tmpDir) - - originalDir, _ := os.Getwd() - os.Chdir(tmpDir) - defer 
os.Chdir(originalDir) - - validSHA := "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" - manifestPath := filepath.Join(tmpDir, "manifest.tsv") - content := "files.sha256\tfiles.drs_uri\n" + validSHA + "\tdrs://example.com:obj1\n" - err = os.WriteFile(manifestPath, []byte(content), 0644) - assert.NoError(t, err) - - err = Cmd.RunE(Cmd, []string{manifestPath}) - assert.NoError(t, err) - - // Verify cache directory was created - _, err = os.Stat(".git/drs/objects") - assert.NoError(t, err) -} diff --git a/cmd/cache/create-cache.go b/cmd/cache/create-cache.go deleted file mode 100644 index aaf786a6..00000000 --- a/cmd/cache/create-cache.go +++ /dev/null @@ -1,97 +0,0 @@ -package cache - -import ( - "encoding/csv" - "fmt" - "io" - "os" - "path/filepath" - - "github.com/calypr/git-drs/common" - "github.com/calypr/git-drs/drsmap" - "github.com/spf13/cobra" -) - -var Cmd = &cobra.Command{ - Use: "create-cache ", - Short: "create a local version of a file manifest containing DRS URIs", - Long: "create a local version of a file manifest containing DRS URIs. 
Enables LFS to map its file object id (sha256) back to a DRS URI by file", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - file := args[0] - - // load file - f, err := os.Open(file) - if err != nil { - return fmt.Errorf("failed to open manifest file: %w", err) - } - defer f.Close() - - // Use encoding/csv with tab delimiter for TSV - r := csv.NewReader(f) - r.Comma = '\t' - - // Read header - header, err := r.Read() - if err != nil { - return fmt.Errorf("failed to read header: %w", err) - } - - // Map column names to indices - colIdx := make(map[string]int, len(header)) - for i, col := range header { - colIdx[col] = i - } - - // Check required columns - shaIdx, shaOk := colIdx["files.sha256"] - drsIdx, drsOk := colIdx["files.drs_uri"] - if !shaOk || !drsOk { - return fmt.Errorf("manifest must contain 'files.sha256' and 'files.drs_uri' columns") - } - - // Read each row - for { - row, err := r.Read() - if err != nil { - if err == io.EOF { - break - } - return fmt.Errorf("error reading manifest file: %w", err) - } - sha := row[shaIdx] - drsURI := row[drsIdx] - fmt.Printf("Indexing DRS URI %s with sha256 %s\n", drsURI, sha) - - // create sha to DRS URI mapping - objPath, err := drsmap.GetObjectPath(common.DRS_REF_DIR, sha) - if err != nil { - return fmt.Errorf("failed to get object path for %s: %w", sha, err) - } - - if err := os.MkdirAll(filepath.Dir(objPath), 0755); err != nil { - return fmt.Errorf("failed to create dir for %s: %w", objPath, err) - } - - contents := fmt.Sprintf("files.drs_uri\n%s\n", drsURI) - if err := os.WriteFile(objPath, []byte(contents), 0644); err != nil { - return fmt.Errorf("failed to write DRS URI for %s: %w", sha, err) - } - - // Split DRS URI into a custom path and write sha to custom path - customPath, err := drsmap.CreateCustomPath(common.DRS_REF_DIR, drsURI) - if err != nil { - return fmt.Errorf("failed to create custom path for %s: %w", drsURI, err) - } - if err := 
os.MkdirAll(filepath.Dir(customPath), 0755); err != nil { - return fmt.Errorf("failed to create dir for %s: %w", customPath, err) - } - if err := os.WriteFile(customPath, []byte(sha), 0644); err != nil { - return fmt.Errorf("failed to write sha for %s: %w", drsURI, err) - } - } - - fmt.Printf("Cache created in %s\n", common.DRS_REF_DIR) - return nil - }, -} diff --git a/cmd/delete/main.go b/cmd/delete/main.go index 3a2ef875..10ef0719 100644 --- a/cmd/delete/main.go +++ b/cmd/delete/main.go @@ -5,7 +5,7 @@ import ( "fmt" "os" - "github.com/calypr/data-client/indexd/hash" + "github.com/calypr/data-client/hash" "github.com/calypr/git-drs/common" "github.com/calypr/git-drs/config" "github.com/calypr/git-drs/drslog" diff --git a/cmd/download/main.go b/cmd/download/main.go deleted file mode 100644 index a4c0db33..00000000 --- a/cmd/download/main.go +++ /dev/null @@ -1,99 +0,0 @@ -package download - -import ( - "context" - "fmt" - - dataClientCommon "github.com/calypr/data-client/common" - "github.com/calypr/data-client/download" - "github.com/calypr/data-client/indexd/hash" - "github.com/calypr/git-drs/common" - "github.com/calypr/git-drs/config" - "github.com/calypr/git-drs/drslog" - "github.com/calypr/git-drs/drsmap" - "github.com/spf13/cobra" -) - -var ( - dstPath string - remote string -) - -// Cmd line declaration -// Cmd line declaration -var Cmd = &cobra.Command{ - Use: "download ", - Short: "Download file using file object ID", - Long: "Download file using file object ID (sha256 hash). 
Use lfs ls-files to get oid", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) != 1 { - cmd.SilenceUsage = false - return fmt.Errorf("error: requires exactly 1 argument (file object ID), received %d\n\nUsage: %s\n\nSee 'git drs download --help' for more details", len(args), cmd.UseLine()) - } - return nil - }, - RunE: func(cmd *cobra.Command, args []string) error { - logger := drslog.GetLogger() - - oid := args[0] - - cfg, err := config.LoadConfig() - if err != nil { - return fmt.Errorf("error loading config: %v", err) - } - - remoteName, err := cfg.GetRemoteOrDefault(remote) - if err != nil { - return fmt.Errorf("error getting default remote: %v", err) - } - - drsClient, err := cfg.GetRemoteClient(remoteName, logger) - if err != nil { - logger.Error(fmt.Sprintf("\nerror creating DRS client: %s", err)) - return err - } - - // get the matching record for this OID - checksumSpec := &hash.Checksum{Type: hash.ChecksumTypeSHA256, Checksum: oid} - records, err := drsClient.GetObjectByHash(context.Background(), checksumSpec) - if err != nil { - return fmt.Errorf("Error looking up OID %s: %v", oid, err) - } - - matchingRecord, err := drsmap.FindMatchingRecord(records, drsClient.GetProjectId()) - if err != nil { - return fmt.Errorf("Error finding matching record for project %s: %v", drsClient.GetProjectId(), err) - } - if matchingRecord == nil { - return fmt.Errorf("No matching record found for project %s and OID %s", drsClient.GetProjectId(), oid) - } - - // download url to destination path or LFS objects if not specified - if dstPath == "" { - dstPath, err = drsmap.GetObjectPath(common.LFS_OBJS_PATH, oid) - } - if err != nil { - return fmt.Errorf("Error getting destination path for OID %s: %v", oid, err) - } - - ctx := dataClientCommon.WithOid(context.Background(), oid) - err = download.DownloadToPath( - ctx, - drsClient.GetGen3Interface(), - matchingRecord.Id, - dstPath, - ) - if err != nil { - return fmt.Errorf("Error downloading file for OID %s 
(GUID: %s): %v", oid, matchingRecord.Id, err) - } - - logger.Debug("file downloaded") - - return nil - }, -} - -func init() { - Cmd.Flags().StringVarP(&remote, "remote", "r", "", "target remote DRS server (default: default_remote)") - Cmd.Flags().StringVarP(&dstPath, "dst", "d", "", "Destination path to save the downloaded file") -} diff --git a/cmd/download/main_test.go b/cmd/download/main_test.go deleted file mode 100644 index 643114ef..00000000 --- a/cmd/download/main_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package download - -import ( - "testing" - - "github.com/calypr/git-drs/internal/testutils" -) - -func TestDownloadCmd(t *testing.T) { - testutils.RunCmdMainTest(t, "download") -} - -func TestValidateArgs(t *testing.T) { - testutils.RunCmdArgsTest(t) -} diff --git a/cmd/initialize/main.go b/cmd/initialize/main.go index 9d534758..5599c94a 100644 --- a/cmd/initialize/main.go +++ b/cmd/initialize/main.go @@ -4,8 +4,10 @@ import ( "fmt" "log/slog" "os" + "os/exec" "path/filepath" "strconv" + "strings" "time" "github.com/calypr/git-drs/config" @@ -66,6 +68,11 @@ var Cmd = &cobra.Command{ if err != nil { return fmt.Errorf("error installing pre-push hook: %v", err) } + // install pre-commit hook + err = installPreCommitHook(logg) + if err != nil { + return fmt.Errorf("error installing pre-commit hook: %v", err) + } // final logs logg.Debug("Git DRS initialized") @@ -155,3 +162,49 @@ exec git lfs pre-push "$remote" "$url" < "$TMPFILE" logger.Debug("pre-push hook installed") return nil } + +func installPreCommitHook(logger *slog.Logger) error { + cmd := exec.Command("git", "rev-parse", "--git-dir") + cmdOut, err := cmd.Output() + if err != nil { + return fmt.Errorf("unable to locate git directory: %w", err) + } + gitDir := strings.TrimSpace(string(cmdOut)) + hooksDir := filepath.Join(gitDir, "hooks") + if err := os.MkdirAll(hooksDir, 0755); err != nil { + return fmt.Errorf("unable to create hooks directory: %w", err) + } + + hookPath := filepath.Join(hooksDir, 
"pre-commit") + hookBody := ` +# .git/hooks/pre-commit +exec git drs precommit +` + hookScript := "#!/bin/sh\n" + hookBody + + existingContent, err := os.ReadFile(hookPath) + if err == nil { + // there is an existing hook, rename it, and let the user know + // Backup existing hook with timestamp + timestamp := time.Now().Format("20060102T150405") + backupPath := hookPath + "." + timestamp + if err := os.WriteFile(backupPath, existingContent, 0644); err != nil { + return fmt.Errorf("unable to back up existing pre-commit hook: %w", err) + } + if err := os.Remove(hookPath); err != nil { + return fmt.Errorf("unable to remove hook after backing up: %w", err) + } + logger.Debug(fmt.Sprintf("pre-commit hook updated; backup written to %s", backupPath)) + } + // If there was an error other than expected not existing, return it + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("unable to read pre-commit hook: %w", err) + } + + err = os.WriteFile(hookPath, []byte(hookScript), 0755) + if err != nil { + return fmt.Errorf("unable to write pre-commit hook: %w", err) + } + logger.Debug("pre-commit hook installed") + return nil +} diff --git a/cmd/initialize/main_test.go b/cmd/initialize/main_test.go index 0da8caf9..26caaa07 100644 --- a/cmd/initialize/main_test.go +++ b/cmd/initialize/main_test.go @@ -33,6 +33,28 @@ func TestInstallPrePushHook(t *testing.T) { } } +func TestInstallPreCommitHook(t *testing.T) { + testutils.SetupTestGitRepo(t) + logger := drslog.NewNoOpLogger() + + if err := installPreCommitHook(logger); err != nil { + t.Fatalf("installPreCommitHook error: %v", err) + } + + hookPath := filepath.Join(".git", "hooks", "pre-commit") + content, err := os.ReadFile(hookPath) + if err != nil { + t.Fatalf("read hook: %v", err) + } + if !strings.Contains(string(content), "git drs precommit") { + t.Fatalf("expected hook to contain git drs precommit") + } + + if err := installPreCommitHook(logger); err != nil { + t.Fatalf("installPreCommitHook second call error: %v", 
err) + } +} + func TestInitGitConfig(t *testing.T) { testutils.SetupTestGitRepo(t) transfers = 2 diff --git a/cmd/list/list_test.go b/cmd/list/list_test.go deleted file mode 100644 index 3c75104c..00000000 --- a/cmd/list/list_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package list - -import ( - "testing" - - "github.com/calypr/git-drs/internal/testutils" - "github.com/stretchr/testify/assert" -) - -func TestListCmdArgs(t *testing.T) { - err := Cmd.Args(Cmd, []string{}) - assert.NoError(t, err) - - err = Cmd.Args(Cmd, []string{"extra"}) - assert.Error(t, err) -} - -func TestListProjectCmdArgs(t *testing.T) { - err := ListProjectCmd.Args(ListProjectCmd, []string{"project-id"}) - assert.NoError(t, err) - - err = ListProjectCmd.Args(ListProjectCmd, []string{}) - assert.Error(t, err) -} - -func TestListRun_Error(t *testing.T) { - _ = testutils.SetupTestGitRepo(t) - // No config - err := Cmd.RunE(Cmd, []string{}) - assert.Error(t, err) -} - -func TestListRun_InvalidRemote(t *testing.T) { - tmpDir := testutils.SetupTestGitRepo(t) - testutils.CreateDefaultTestConfig(t, tmpDir) - remote = "invalid" - err := Cmd.RunE(Cmd, []string{}) - assert.Error(t, err) -} diff --git a/cmd/list/main.go b/cmd/list/main.go deleted file mode 100644 index d00e7555..00000000 --- a/cmd/list/main.go +++ /dev/null @@ -1,189 +0,0 @@ -package list - -import ( - "context" - "fmt" - "io" - "os" - - "github.com/bytedance/sonic" - "github.com/calypr/data-client/indexd/drs" - "github.com/calypr/data-client/indexd/hash" - "github.com/calypr/git-drs/config" - "github.com/calypr/git-drs/drslog" - "github.com/spf13/cobra" -) - -var ( - outJson bool = false - outFile string - listOutFile string - remote string -) - -var checksumPref = []hash.ChecksumType{hash.ChecksumTypeSHA256, hash.ChecksumTypeMD5, hash.ChecksumTypeETag} - -func getChecksumPos(q hash.ChecksumType, a []hash.ChecksumType) int { - for i, s := range a { - if q == s { - return i - } - } - return -1 -} - -// Pick out the most preferred checksum to 
display -func getCheckSumStr(obj drs.DRSObject) string { - curPos := len(checksumPref) + 1 - curVal := "" - for checksumType, checksum := range hash.ConvertHashInfoToMap(obj.Checksums) { - c := getChecksumPos(hash.ChecksumType(checksumType), checksumPref) - if c != -1 && c < curPos { - curPos = c - curVal = checksumType + ":" + checksum - } - } - return curVal -} - -// Cmd line declaration -var Cmd = &cobra.Command{ - Use: "list", - Short: "List DRS entities from server", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) != 0 { - cmd.SilenceUsage = false - return fmt.Errorf("error: accepts no arguments, received %d\n\nUsage: %s\n\nSee 'git drs list --help' for more details", len(args), cmd.UseLine()) - } - return nil - }, - RunE: func(cmd *cobra.Command, args []string) error { - logger := drslog.GetLogger() - - var outWriter io.Writer - if listOutFile != "" { - f, err := os.Create(listOutFile) - if err != nil { - return err - } - defer f.Close() - outWriter = f - } else { - outWriter = os.Stdout - } - - conf, err := config.LoadConfig() - if err != nil { - return fmt.Errorf("error loading config: %v", err) - } - - remoteName, err := conf.GetRemoteOrDefault(remote) - if err != nil { - return fmt.Errorf("error getting default remote: %v", err) - } - - client, err := conf.GetRemoteClient(remoteName, logger) - if err != nil { - logger.Debug("Client failed") - return err - } - objChan, err := client.ListObjects(context.Background()) - if err != nil { - return err - } - if !outJson { - fmt.Fprintf(outWriter, "%-55s\t%-15s\t%-75s\t%s\n", "URI", "Size", "Checksum", "Name") - } - - // for each result, check for error and print - for objResult := range objChan { - if objResult.Error != nil { - return objResult.Error - } - obj := objResult.Object - if outJson { - out, err := sonic.ConfigFastest.Marshal(*obj) - if err != nil { - return err - } - fmt.Fprintf(outWriter, "%s\n", string(out)) - } else { - fmt.Fprintf(outWriter, "%s\t%-15d\t%-75s\t%s\n", 
obj.SelfURI, obj.Size, getCheckSumStr(*obj), obj.Name) - } - } - return nil - }, -} -var ListProjectCmd = &cobra.Command{ - Use: "list-project ", - Short: "List DRS entities from server", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) != 1 { - cmd.SilenceUsage = false - return fmt.Errorf("error: requires exactly 1 argument (project ID), received %d\n\nUsage: %s\n\nSee 'git drs list-project --help' for more details", len(args), cmd.UseLine()) - } - return nil - }, - RunE: func(cmd *cobra.Command, args []string) error { - logger := drslog.GetLogger() - - conf, err := config.LoadConfig() - if err != nil { - return fmt.Errorf("error loading config: %v", err) - } - - remoteName, err := conf.GetRemoteOrDefault(remote) - if err != nil { - return fmt.Errorf("error getting default remote: %v", err) - } - - client, err := conf.GetRemoteClient(remoteName, logger) - if err != nil { - return err - } - objChan, err := client.ListObjectsByProject(context.Background(), args[0]) - if err != nil { - return err - } - - var f *os.File - var outWriter io.Writer - if outFile != "" { - f, err = os.Create(outFile) - if err != nil { - return err - } - defer f.Close() - outWriter = f - } else { - outWriter = os.Stdout - } - for objResult := range objChan { - if objResult.Error != nil { - return objResult.Error - } - obj := objResult.Object - out, err := sonic.ConfigFastest.Marshal(*obj) - if err != nil { - return err - } - _, err = outWriter.Write(out) - if err != nil { - return err - } - _, err = outWriter.Write([]byte("\n")) - if err != nil { - return err - } - } - return nil - }, -} - -func init() { - ListProjectCmd.Flags().StringVarP(&remote, "remote", "r", "", "target remote DRS server (default: default_remote)") - ListProjectCmd.Flags().StringVarP(&outFile, "out", "o", outFile, "File path to save output to") - Cmd.Flags().StringVarP(&remote, "remote", "r", "", "target remote DRS server (default: default_remote)") - Cmd.Flags().StringVarP(&listOutFile, "out", 
"o", listOutFile, "File path to save output to") - Cmd.Flags().BoolVarP(&outJson, "json", "j", outJson, "Output formatted as JSON") -} diff --git a/cmd/list/main_test.go b/cmd/list/main_test.go deleted file mode 100644 index d95515f0..00000000 --- a/cmd/list/main_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package list - -import ( - "testing" - - "github.com/calypr/data-client/indexd/drs" - "github.com/calypr/data-client/indexd/hash" -) - -func TestGetChecksumPos(t *testing.T) { - if pos := getChecksumPos(hash.ChecksumTypeSHA256, checksumPref); pos != 0 { - t.Fatalf("expected SHA256 at pos 0, got %d", pos) - } - if pos := getChecksumPos(hash.ChecksumType("missing"), checksumPref); pos != -1 { - t.Fatalf("expected missing checksum -1, got %d", pos) - } -} - -func TestGetCheckSumStr(t *testing.T) { - obj := drs.DRSObject{Checksums: hash.HashInfo{MD5: "md5", SHA256: "sha"}} - value := getCheckSumStr(obj) - if value != "sha256:sha" { - t.Fatalf("unexpected checksum string: %s", value) - } -} diff --git a/cmd/listconfig/listconfig_test.go b/cmd/listconfig/listconfig_test.go deleted file mode 100644 index f59660e3..00000000 --- a/cmd/listconfig/listconfig_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package listconfig - -import ( - "testing" - - "github.com/calypr/git-drs/internal/testutils" - "github.com/stretchr/testify/assert" -) - -func TestListConfigCmdArgs(t *testing.T) { - err := Cmd.Args(Cmd, []string{}) - assert.NoError(t, err) - - err = Cmd.Args(Cmd, []string{"extra"}) - assert.Error(t, err) -} - -func TestListConfigRun_Error(t *testing.T) { - _ = testutils.SetupTestGitRepo(t) - // No config should not error, just return empty - err := Cmd.RunE(Cmd, []string{}) - assert.NoError(t, err) -} diff --git a/cmd/listconfig/main.go b/cmd/listconfig/main.go deleted file mode 100644 index fc34feb0..00000000 --- a/cmd/listconfig/main.go +++ /dev/null @@ -1,54 +0,0 @@ -package listconfig - -import ( - "fmt" - "os" - - "github.com/bytedance/sonic" - 
"github.com/calypr/git-drs/config" - "github.com/spf13/cobra" - "gopkg.in/yaml.v3" -) - -var ( - jsonOutput bool -) - -// Cmd represents the list-config command -var Cmd = &cobra.Command{ - Use: "list-config", - Short: "Display the current configuration", - Long: "Pretty prints the current configuration file in YAML format", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) != 0 { - cmd.SilenceUsage = false - return fmt.Errorf("error: accepts no arguments, received %d\n\nUsage: %s\n\nSee 'git drs list-config --help' for more details", len(args), cmd.UseLine()) - } - return nil - }, - RunE: func(cmd *cobra.Command, args []string) error { - // Load the current configuration - cfg, err := config.LoadConfig() - if err != nil { - return fmt.Errorf("failed to load config: %w", err) - } - - if jsonOutput { - // Output as JSON if requested - encoder := sonic.ConfigFastest.NewEncoder(os.Stdout) - encoder.SetIndent("", " ") - return encoder.Encode(cfg) - } else { - // Default YAML output with nice formatting - encoder := yaml.NewEncoder(os.Stdout) - encoder.SetIndent(2) - defer encoder.Close() - - return encoder.Encode(cfg) - } - }, -} - -func init() { - Cmd.Flags().BoolVarP(&jsonOutput, "json", "j", false, "output in JSON format") -} diff --git a/cmd/precommit/main.go b/cmd/precommit/main.go new file mode 100644 index 00000000..ac2f3047 --- /dev/null +++ b/cmd/precommit/main.go @@ -0,0 +1,547 @@ +// Package precommit +// ------------------------------------- +// LFS-only local cache updater for: +// - Path -> OID : .git/drs/pre-commit/v1/paths/.json +// - OID -> Paths + S3 URL hint : .git/drs/pre-commit/v1/oids/.json +// +// This hook is intentionally: +// - LFS-only (non-LFS paths are ignored) +// - local-only (no network, no server index reads) +// - index-based (reads STAGED content via `git show :`) +// +// Note: This is a reference implementation. Adjust logging/policy as desired. 
+package precommit + +import ( + "bufio" + "bytes" + "context" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "sort" + "strings" + "time" + + "github.com/spf13/cobra" +) + +const ( + cacheVersionDir = "drs/pre-commit/v1" + lfsSpecLine = "version https://git-lfs.github.com/spec/v1" +) + +type PathEntry struct { + Path string `json:"path"` + LFSOID string `json:"lfs_oid"` + UpdatedAt string `json:"updated_at"` +} + +type OIDEntry struct { + LFSOID string `json:"lfs_oid"` + Paths []string `json:"paths"` + S3URL string `json:"s3_url,omitempty"` // hint only; may be empty + UpdatedAt string `json:"updated_at"` + ContentChange bool `json:"content_changed"` +} + +type ChangeKind int + +const ( + KindAdd ChangeKind = iota + KindModify + KindDelete + KindRename +) + +type Change struct { + Kind ChangeKind + OldPath string // for rename + NewPath string // for rename (and for add/modify/delete uses NewPath) + Status string // raw status, e.g. "A", "M", "D", "R100" +} + +// Cmd line declaration +var Cmd = &cobra.Command{ + Use: "precommit", + Short: "pre-commit hook to update local DRS cache", + Long: "Pre-commit hook that updates the local DRS pre-commit cache", + Args: cobra.ExactArgs(0), + RunE: func(cmd *cobra.Command, args []string) error { + return run(context.Background()) + }, +} + +func main() { + ctx := context.Background() + if err := run(ctx); err != nil { + // For a reference impl, treat errors as non-fatal unless you want strict enforcement. + // Exiting non-zero blocks the commit. 
+ fmt.Fprintf(os.Stderr, "pre-commit drs cache: %v\n", err) + os.Exit(1) + } +} + +func run(ctx context.Context) error { + gitDir, err := gitRevParseGitDir(ctx) + if err != nil { + return err + } + + cacheRoot := filepath.Join(gitDir, cacheVersionDir) + pathsDir := filepath.Join(cacheRoot, "paths") + oidsDir := filepath.Join(cacheRoot, "oids") + tombsDir := filepath.Join(cacheRoot, "tombstones") + + if err := os.MkdirAll(pathsDir, 0o755); err != nil { + return err + } + if err := os.MkdirAll(oidsDir, 0o755); err != nil { + return err + } + _ = os.MkdirAll(tombsDir, 0o755) // optional + + changes, err := stagedChanges(ctx) + if err != nil { + return err + } + if len(changes) == 0 { + return nil + } + + now := time.Now().UTC().Format(time.RFC3339) + + // Process renames first so subsequent add/modify logic sees the "new" path. + // This mirrors how we want cache paths to follow staged paths. + for _, ch := range changes { + if ch.Kind != KindRename { + continue + } + // Only act if BOTH old and new are LFS in scope? Prefer: + // - If the new path is LFS, we migrate. + // - If it isn't LFS, we remove old path entry (out of scope). + newOID, newIsLFS, err := stagedLFSOID(ctx, ch.NewPath) + if err != nil { + // If file doesn't exist in index due to weird staging, skip. 
+ continue + } + + oldPathFile := pathEntryFile(pathsDir, ch.OldPath) + newPathFile := pathEntryFile(pathsDir, ch.NewPath) + + if newIsLFS { + // Move/overwrite path entry file + if err := moveFileBestEffort(oldPathFile, newPathFile); err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("rename migrate path entry: %w", err) + } + + // Ensure path entry content correct + if err := writeJSONAtomic(newPathFile, PathEntry{ + Path: ch.NewPath, + LFSOID: newOID, + UpdatedAt: now, + }); err != nil { + return err + } + + // Update oid entry: replace old path with new path for that OID + if err := oidAddOrReplacePath(oidsDir, newOID, ch.OldPath, ch.NewPath, now, false); err != nil { + return err + } + } else { + // Out of scope now: remove any cached path entry. + _ = os.Remove(oldPathFile) + } + } + + // Process adds/modifies/deletes (and renames again just to ensure content correctness on new path). + for _, ch := range changes { + switch ch.Kind { + case KindAdd, KindModify: + if err := handleUpsert(ctx, pathsDir, oidsDir, ch.NewPath, now); err != nil { + return err + } + case KindRename: + // Treat like upsert on NewPath to ensure OID/path consistency if content also changed. + if err := handleUpsert(ctx, pathsDir, oidsDir, ch.NewPath, now); err != nil { + return err + } + // Optionally also remove old path from *other* OID entry if rename+content-change changed OID. + // We'll do it inside handleUpsert by checking previous cached OID for that path (after move). + case KindDelete: + if err := handleDelete(ctx, pathsDir, oidsDir, tombsDir, ch.NewPath, now); err != nil { + return err + } + } + } + + return nil +} + +func handleUpsert(ctx context.Context, pathsDir, oidsDir, path, now string) error { + oid, isLFS, err := stagedLFSOID(ctx, path) + if err != nil { + // If file isn't in index, ignore. + return nil + } + if !isLFS { + // Out of scope. 
+ return nil + } + + pathFile := pathEntryFile(pathsDir, path) + + // Load previous path entry if it exists to detect content changes. + var prev PathEntry + prevExists := false + if b, err := os.ReadFile(pathFile); err == nil { + _ = json.Unmarshal(b, &prev) + if prev.Path != "" && prev.LFSOID != "" { + prevExists = true + } + } + + // Write/update path entry. + if err := writeJSONAtomic(pathFile, PathEntry{ + Path: path, + LFSOID: oid, + UpdatedAt: now, + }); err != nil { + return err + } + + // Update OID entry for new oid: add path. + contentChanged := prevExists && prev.LFSOID != oid + if err := oidAddOrReplacePath(oidsDir, oid, "", path, now, contentChanged); err != nil { + return err + } + + // If content changed, remove path from the *old* oid entry (best effort). + if contentChanged { + _ = oidRemovePath(oidsDir, prev.LFSOID, path, now) + } + + return nil +} + +func handleDelete(ctx context.Context, pathsDir, oidsDir, tombsDir, path, now string) error { + // Only consider deletion if it was previously an LFS entry (cache-driven). + pathFile := pathEntryFile(pathsDir, path) + b, err := os.ReadFile(pathFile) + if err != nil { + // nothing to do + return nil + } + var pe PathEntry + if err := json.Unmarshal(b, &pe); err != nil { + // corrupted cache; remove it + _ = os.Remove(pathFile) + return nil + } + // Remove path entry. + _ = os.Remove(pathFile) + + // Remove this path from the old oid entry (best effort). + if pe.LFSOID != "" { + _ = oidRemovePath(oidsDir, pe.LFSOID, path, now) + } + + // Optional tombstone. 
+ tombFile := filepath.Join(tombsDir, encodePath(path)+".json") + _ = writeJSONAtomic(tombFile, map[string]string{ + "path": path, + "deleted_at": now, + }) + + return nil +} + +// stagedChanges parses: git diff --cached --name-status -M +// Formats: +// +// Apath +// Mpath +// Dpath +// R100oldnew +func stagedChanges(ctx context.Context) ([]Change, error) { + out, err := git(ctx, "diff", "--cached", "--name-status", "-M") + if err != nil { + return nil, err + } + var changes []Change + sc := bufio.NewScanner(bytes.NewReader(out)) + for sc.Scan() { + line := sc.Text() + if strings.TrimSpace(line) == "" { + continue + } + parts := strings.Split(line, "\t") + if len(parts) < 2 { + continue + } + status := parts[0] + switch { + case status == "A": + changes = append(changes, Change{Kind: KindAdd, NewPath: parts[1], Status: status}) + case status == "M": + changes = append(changes, Change{Kind: KindModify, NewPath: parts[1], Status: status}) + case status == "D": + changes = append(changes, Change{Kind: KindDelete, NewPath: parts[1], Status: status}) + case strings.HasPrefix(status, "R") && len(parts) >= 3: + changes = append(changes, Change{Kind: KindRename, OldPath: parts[1], NewPath: parts[2], Status: status}) + default: + // ignore other statuses (C, T, U, etc) for this reference impl + } + } + if err := sc.Err(); err != nil { + return nil, err + } + return changes, nil +} + +// stagedLFSOID returns (oid, isLFS, err) based on STAGED content. +// isLFS is true only if the staged file is a valid LFS pointer with an oid sha256 line. +func stagedLFSOID(ctx context.Context, path string) (string, bool, error) { + out, err := git(ctx, "show", ":"+path) + if err != nil { + // path may not exist in index (deleted/intent-to-add weirdness) + return "", false, err + } + + // Fast parse: look for spec line and oid line near top. + // LFS pointer files are small; scanning full content is fine. 
// gitRevParseGitDir resolves the repository's .git directory as an absolute
// path; relative `rev-parse --git-dir` output is resolved against the
// repository toplevel.
func gitRevParseGitDir(ctx context.Context) (string, error) {
	out, err := git(ctx, "rev-parse", "--git-dir")
	if err != nil {
		return "", err
	}
	gitDir := strings.TrimSpace(string(out))
	if gitDir == "" {
		return "", errors.New("could not determine .git dir")
	}
	if filepath.IsAbs(gitDir) {
		return gitDir, nil
	}
	rootOut, err := git(ctx, "rev-parse", "--show-toplevel")
	if err != nil {
		return "", err
	}
	root := strings.TrimSpace(string(rootOut))
	return filepath.Join(root, gitDir), nil
}

// git runs a git subcommand and returns its stdout. On failure the error
// carries the command line plus trimmed stderr for debugging (without
// leaking massive output).
func git(ctx context.Context, args ...string) ([]byte, error) {
	var stdout, stderr bytes.Buffer
	cmd := exec.CommandContext(ctx, "git", args...)
	cmd.Env = os.Environ()
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr
	if err := cmd.Run(); err != nil {
		msg := strings.TrimSpace(stderr.String())
		if msg == "" {
			msg = err.Error()
		}
		return nil, fmt.Errorf("git %s: %s", strings.Join(args, " "), msg)
	}
	return stdout.Bytes(), nil
}
+func pathEntryFile(pathsDir, path string) string { + return filepath.Join(pathsDir, encodePath(path)+".json") +} + +func encodePath(path string) string { + // base64url encoding of the UTF-8 path string (no padding) is simple and safe. + return base64.RawURLEncoding.EncodeToString([]byte(path)) +} + +func oidEntryFile(oidsDir, oid string) string { + // OID contains ":"; make it filesystem safe but still human readable. + // Use a stable transform; here: sha256 of oid string to avoid path length issues. + sum := sha256.Sum256([]byte(oid)) + return filepath.Join(oidsDir, fmt.Sprintf("%x.json", sum[:])) +} + +// oidAddOrReplacePath: +// - loads oid entry (if exists) +// - adds newPath to paths[] +// - if oldPath != "" and present, replaces it with newPath +// - sets ContentChange flag if requested (ORed into existing flag) +// - preserves existing s3_url hint +func oidAddOrReplacePath(oidsDir, oid, oldPath, newPath, now string, contentChanged bool) error { + f := oidEntryFile(oidsDir, oid) + + entry := OIDEntry{ + LFSOID: oid, + Paths: []string{}, + UpdatedAt: now, + } + if b, err := os.ReadFile(f); err == nil { + _ = json.Unmarshal(b, &entry) + // ensure oid is set even if old file was incomplete + entry.LFSOID = oid + } + + paths := make(map[string]struct{}, len(entry.Paths)+1) + for _, p := range entry.Paths { + paths[p] = struct{}{} + } + + if oldPath != "" { + delete(paths, oldPath) + } + if newPath != "" { + paths[newPath] = struct{}{} + } + + entry.Paths = keysSorted(paths) + entry.UpdatedAt = now + entry.ContentChange = entry.ContentChange || contentChanged + + return writeJSONAtomic(f, entry) +} + +func oidRemovePath(oidsDir, oid, path, now string) error { + f := oidEntryFile(oidsDir, oid) + + b, err := os.ReadFile(f) + if err != nil { + return err + } + var entry OIDEntry + if err := json.Unmarshal(b, &entry); err != nil { + return err + } + paths := make(map[string]struct{}, len(entry.Paths)) + for _, p := range entry.Paths { + if p == path { + continue + 
} + paths[p] = struct{}{} + } + entry.Paths = keysSorted(paths) + entry.UpdatedAt = now + + // If no paths remain, keep the file (it may still hold s3_url hint) or delete it. + // This ADR allows stale entries; keeping is fine. Optionally delete when empty: + // if len(entry.Paths) == 0 && entry.S3URL == "" { return os.Remove(f) } + + return writeJSONAtomic(f, entry) +} + +func keysSorted(m map[string]struct{}) []string { + out := make([]string, 0, len(m)) + for k := range m { + out = append(out, k) + } + sort.Strings(out) + return out +} + +// writeJSONAtomic writes JSON to a temp file then renames it into place. +// This avoids partially written cache files if the process is interrupted. +func writeJSONAtomic(path string, v any) error { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + + tmp := path + ".tmp" + f, err := os.OpenFile(tmp, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644) + if err != nil { + return err + } + defer func() { _ = f.Close() }() + + enc := json.NewEncoder(f) + enc.SetIndent("", " ") + if err := enc.Encode(v); err != nil { + _ = os.Remove(tmp) + return err + } + if err := f.Sync(); err != nil { + _ = os.Remove(tmp) + return err + } + if err := f.Close(); err != nil { + _ = os.Remove(tmp) + return err + } + return os.Rename(tmp, path) +} + +func moveFileBestEffort(src, dst string) error { + // Ensure destination directory exists. + if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil { + return err + } + // Rename will fail across devices; fall back to copy+remove. 
+ if err := os.Rename(src, dst); err == nil { + return nil + } else if errors.Is(err, os.ErrNotExist) { + return err + } + + in, err := os.Open(src) + if err != nil { + return err + } + defer in.Close() + + out, err := os.OpenFile(dst, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644) + if err != nil { + return err + } + + if _, err := io.Copy(out, in); err != nil { + _ = out.Close() + return err + } + if err := out.Close(); err != nil { + return err + } + return os.Remove(src) +} diff --git a/cmd/precommit/main_test.go b/cmd/precommit/main_test.go new file mode 100644 index 00000000..8a0fb0c6 --- /dev/null +++ b/cmd/precommit/main_test.go @@ -0,0 +1,146 @@ +package precommit + +import ( + "context" + "encoding/json" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestHandleUpsertIgnoresNonLFSFile(t *testing.T) { + repo := setupGitRepo(t) + oldwd := mustChdir(t, repo) + t.Cleanup(func() { _ = os.Chdir(oldwd) }) + + path := filepath.Join(repo, "data", "file.txt") + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(path, []byte("plain content"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + gitCmd(t, repo, "add", "data/file.txt") + + cacheRoot := filepath.Join(repo, ".git", "drs", "pre-commit", "v1") + pathsDir := filepath.Join(cacheRoot, "paths") + oidsDir := filepath.Join(cacheRoot, "oids") + if err := os.MkdirAll(pathsDir, 0o755); err != nil { + t.Fatalf("mkdir paths: %v", err) + } + if err := os.MkdirAll(oidsDir, 0o755); err != nil { + t.Fatalf("mkdir oids: %v", err) + } + + now := time.Now().UTC().Format(time.RFC3339) + if err := handleUpsert(context.Background(), pathsDir, oidsDir, "data/file.txt", now); err != nil { + t.Fatalf("handleUpsert: %v", err) + } + + pathEntry := pathEntryFile(pathsDir, "data/file.txt") + if _, err := os.Stat(pathEntry); !os.IsNotExist(err) { + t.Fatalf("expected no cache entry for non-LFS file, got err=%v", err) + } +} + +func 
TestHandleUpsertWritesLFSPointerCache(t *testing.T) { + repo := setupGitRepo(t) + oldwd := mustChdir(t, repo) + t.Cleanup(func() { _ = os.Chdir(oldwd) }) + + path := filepath.Join(repo, "data", "file.bin") + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + lfsPointer := strings.Join([]string{ + "version https://git-lfs.github.com/spec/v1", + "oid sha256:deadbeef", + "size 12", + "", + }, "\n") + if err := os.WriteFile(path, []byte(lfsPointer), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + gitCmd(t, repo, "add", "data/file.bin") + + cacheRoot := filepath.Join(repo, ".git", "drs", "pre-commit", "v1") + pathsDir := filepath.Join(cacheRoot, "paths") + oidsDir := filepath.Join(cacheRoot, "oids") + if err := os.MkdirAll(pathsDir, 0o755); err != nil { + t.Fatalf("mkdir paths: %v", err) + } + if err := os.MkdirAll(oidsDir, 0o755); err != nil { + t.Fatalf("mkdir oids: %v", err) + } + + now := time.Now().UTC().Format(time.RFC3339) + if err := handleUpsert(context.Background(), pathsDir, oidsDir, "data/file.bin", now); err != nil { + t.Fatalf("handleUpsert: %v", err) + } + + pathEntry := pathEntryFile(pathsDir, "data/file.bin") + pathData, err := os.ReadFile(pathEntry) + if err != nil { + t.Fatalf("read path entry: %v", err) + } + var pathCache PathEntry + if err := json.Unmarshal(pathData, &pathCache); err != nil { + t.Fatalf("unmarshal path entry: %v", err) + } + if pathCache.Path != "data/file.bin" { + t.Fatalf("expected path entry to be data/file.bin, got %q", pathCache.Path) + } + if pathCache.LFSOID != "sha256:deadbeef" { + t.Fatalf("expected lfs oid sha256:deadbeef, got %q", pathCache.LFSOID) + } + + oidEntry := oidEntryFile(oidsDir, "sha256:deadbeef") + oidData, err := os.ReadFile(oidEntry) + if err != nil { + t.Fatalf("read oid entry: %v", err) + } + var oidCache OIDEntry + if err := json.Unmarshal(oidData, &oidCache); err != nil { + t.Fatalf("unmarshal oid entry: %v", err) + } + if oidCache.LFSOID != 
"sha256:deadbeef" { + t.Fatalf("expected oid entry sha256:deadbeef, got %q", oidCache.LFSOID) + } + if len(oidCache.Paths) != 1 || oidCache.Paths[0] != "data/file.bin" { + t.Fatalf("expected oid paths to include data/file.bin, got %v", oidCache.Paths) + } +} + +func setupGitRepo(t *testing.T) string { + t.Helper() + dir := t.TempDir() + gitCmd(t, dir, "init") + gitCmd(t, dir, "config", "user.email", "test@example.com") + gitCmd(t, dir, "config", "user.name", "Test User") + return dir +} + +func mustChdir(t *testing.T, dir string) string { + t.Helper() + old, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd: %v", err) + } + if err := os.Chdir(dir); err != nil { + t.Fatalf("Chdir(%s): %v", dir, err) + } + return old +} + +func gitCmd(t *testing.T, dir string, args ...string) { + t.Helper() + cmd := exec.Command("git", args...) + cmd.Dir = dir + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("git %s failed: %v (%s)", strings.Join(args, " "), err, string(out)) + } +} diff --git a/cmd/prepush/main.go b/cmd/prepush/main.go index b28d60c1..dc28335a 100644 --- a/cmd/prepush/main.go +++ b/cmd/prepush/main.go @@ -2,16 +2,23 @@ package prepush import ( "bufio" + "context" "fmt" "io" + "log/slog" "os" + "os/exec" "sort" "strings" + "time" + "github.com/calypr/data-client/drs" "github.com/calypr/git-drs/client/indexd" "github.com/calypr/git-drs/config" "github.com/calypr/git-drs/drslog" "github.com/calypr/git-drs/drsmap" + "github.com/calypr/git-drs/lfs" + "github.com/calypr/git-drs/precommit_cache" "github.com/spf13/cobra" ) @@ -22,101 +29,315 @@ var Cmd = &cobra.Command{ Long: "Pre-push hook that updates DRS objects before transfer", Args: cobra.RangeArgs(0, 2), RunE: func(cmd *cobra.Command, args []string) error { - //myLogger := drslog.GetLogger() - myLogger, err := drslog.NewLogger("", false) - if err != nil { - return fmt.Errorf("error creating logger: %v", err) - } + return NewPrePushService().Run(args, os.Stdin) + }, +} - 
myLogger.Debug("~~~~~~~~~~~~~ START: pre-push ~~~~~~~~~~~~~") +type PrePushService struct { + newLogger func(string, bool) (*slog.Logger, error) + loadConfig func() (*config.Config, error) + updateDrsObjects func(drs.ObjectBuilder, map[string]lfs.LfsFileInfo, drsmap.UpdateOptions) error + createTempFile func(dir, pattern string) (*os.File, error) +} - cfg, err := config.LoadConfig() - if err != nil { - return fmt.Errorf("error getting config: %v", err) - } +func NewPrePushService() *PrePushService { + return &PrePushService{ + newLogger: drslog.NewLogger, + loadConfig: config.LoadConfig, + updateDrsObjects: drsmap.UpdateDrsObjectsWithFiles, + createTempFile: os.CreateTemp, + } +} + +func (s *PrePushService) Run(args []string, stdin io.Reader) error { + ctx := context.Background() + myLogger, err := s.newLogger("", false) + if err != nil { + return fmt.Errorf("error creating logger: %v", err) + } + + myLogger.Info("~~~~~~~~~~~~~ START: pre-push ~~~~~~~~~~~~~") + + cfg, err := s.loadConfig() + if err != nil { + return fmt.Errorf("error getting config: %v", err) + } + + gitRemoteName, gitRemoteLocation := parseRemoteArgs(args) + myLogger.Debug(fmt.Sprintf("git remote name: %s, git remote location: %s", gitRemoteName, gitRemoteLocation)) + + remote, err := cfg.GetDefaultRemote() + if err != nil { + myLogger.Debug(fmt.Sprintf("Warning. Error getting default remote: %v", err)) + fmt.Fprintln(os.Stderr, "Warning. Skipping DRS preparation. Error getting default remote:", err) + return nil + } + + // get the remote client + cli, err := cfg.GetRemoteClient(remote, myLogger) + if err != nil { + // Print warning to stderr and return success (exit 0) + fmt.Fprintln(os.Stderr, "Warning. Skipping DRS preparation. Error getting remote client:", err) + myLogger.Debug(fmt.Sprintf("Warning. Skipping DRS preparation. 
Error getting remote client: %v", err)) + // Check for GitDrsIdxdClient + } + dc, ok := cli.(*indexd.GitDrsIdxdClient) + if !ok { + return fmt.Errorf("cli is not IndexdClient: %T", cli) + } + myLogger.Debug(fmt.Sprintf("Current server: %s", dc.Config.ProjectId)) + remoteConfig := cfg.GetRemote(remote) + if remoteConfig == nil { + fmt.Fprintln(os.Stderr, "Warning. Skipping DRS preparation. Error getting remote configuration.") + myLogger.Debug("Warning. Skipping DRS preparation. Error getting remote configuration.") + return nil + } + + builder := drs.NewObjectBuilder(remoteConfig.GetBucketName(), remoteConfig.GetProjectId()) + myLogger.Debug(fmt.Sprintf("Current server project: %s", builder.ProjectID)) + + tmp, err := bufferStdin(stdin, s.createTempFile) + if err != nil { + myLogger.Error(fmt.Sprintf("error buffering stdin: %v", err)) + return err + } + defer func() { + _ = tmp.Close() + _ = os.Remove(tmp.Name()) + }() + + refs, err := readPushedRefs(tmp) + if err != nil { + myLogger.Error(fmt.Sprintf("error reading pushed refs: %v", err)) + return err + } + branches := branchesFromRefs(refs) - //Command-line arguments: The hook receives two parameters: - //* The name of the remote (e.g., origin). - //* The remote's location/URL (e.g., github.com). - // Create gitRemoteName and gitRemoteLocation from args. 
- myLogger.Debug(fmt.Sprintf("pre-push args: %v", args)) - var gitRemoteName, gitRemoteLocation string - if len(args) >= 1 { - gitRemoteName = args[0] + cache, cacheReady := openCache(ctx, myLogger) + lfsFiles, usedCache, err := collectLfsFiles(ctx, cache, cacheReady, gitRemoteName, gitRemoteLocation, branches, refs, myLogger) + if err != nil { + myLogger.Error(fmt.Sprintf("error collecting LFS files: %v", err)) + return err + } + + myLogger.Debug(fmt.Sprintf("Preparing DRS objects for push branches: %v (cache=%v)", branches, usedCache)) + err = s.updateDrsObjects(builder, lfsFiles, drsmap.UpdateOptions{ + Cache: cache, + PreferCacheURL: usedCache, + Logger: myLogger, + }) + if err != nil { + myLogger.Error(fmt.Sprintf("UpdateDrsObjects failed: %v", err)) + return err + } + myLogger.Info("~~~~~~~~~~~~~ COMPLETED: pre-push ~~~~~~~~~~~~~") + return nil +} + +func parseRemoteArgs(args []string) (string, string) { + var gitRemoteName, gitRemoteLocation string + if len(args) >= 1 { + gitRemoteName = args[0] + } + if len(args) >= 2 { + gitRemoteLocation = args[1] + } + if gitRemoteName == "" { + gitRemoteName = "origin" + } + return gitRemoteName, gitRemoteLocation +} + +type pushedRef struct { + LocalRef string + LocalSHA string + RemoteRef string + RemoteSHA string +} + +func bufferStdin(stdin io.Reader, createTempFile func(dir, pattern string) (*os.File, error)) (*os.File, error) { + tmp, err := createTempFile("", "prepush-stdin-*") + if err != nil { + return nil, fmt.Errorf("error creating temp file for stdin: %w", err) + } + + if _, err := io.Copy(tmp, stdin); err != nil { + return nil, fmt.Errorf("error buffering stdin: %w", err) + } + + if _, err := tmp.Seek(0, 0); err != nil { + return nil, fmt.Errorf("error seeking temp stdin: %w", err) + } + return tmp, nil +} + +// readPushedBranches reads git push lines from the provided temp file, +// extracts unique local branch names for refs under `refs/heads/` and +// returns them sorted. 
The file is rewound to the start before returning. +func readPushedRefs(f io.ReadSeeker) ([]pushedRef, error) { + // Ensure we read from start + // example: + // refs/heads/main 67890abcdef1234567890abcdef1234567890abcd refs/heads/main 12345abcdef67890abcdef1234567890abcdef12 + if _, err := f.Seek(0, 0); err != nil { + return nil, err + } + scanner := bufio.NewScanner(f) + refs := make([]pushedRef, 0) + for scanner.Scan() { + line := scanner.Text() + fields := strings.Fields(line) + if len(fields) < 4 { + continue } - if len(args) >= 2 { - gitRemoteLocation = args[1] + refs = append(refs, pushedRef{ + LocalRef: fields[0], + LocalSHA: fields[1], + RemoteRef: fields[2], + RemoteSHA: fields[3], + }) + } + if err := scanner.Err(); err != nil { + return nil, err + } + // Rewind so caller can reuse the file + if _, err := f.Seek(0, 0); err != nil { + return nil, err + } + return refs, nil +} + +func branchesFromRefs(refs []pushedRef) []string { + const prefix = "refs/heads/" + set := make(map[string]struct{}) + for _, ref := range refs { + if strings.HasPrefix(ref.LocalRef, prefix) { + branch := strings.TrimPrefix(ref.LocalRef, prefix) + if branch != "" { + set[branch] = struct{}{} + } } - if gitRemoteName == "" { - gitRemoteName = "origin" + } + branches := make([]string, 0, len(set)) + for b := range set { + branches = append(branches, b) + } + sort.Strings(branches) + return branches +} + +func openCache(ctx context.Context, logger *slog.Logger) (*precommit_cache.Cache, bool) { + cache, err := precommit_cache.Open(ctx) + if err != nil { + logger.Debug(fmt.Sprintf("pre-commit cache unavailable: %v", err)) + return nil, false + } + if _, err := os.Stat(cache.Root); err != nil { + if os.IsNotExist(err) { + logger.Debug("pre-commit cache missing; continuing without cache") + } else { + logger.Debug(fmt.Sprintf("pre-commit cache access error: %v", err)) } - myLogger.Debug(fmt.Sprintf("git remote name: %s, git remote location: %s", gitRemoteName, gitRemoteLocation)) + 
return nil, false + } + return cache, true +} - // get the default remote from the .git/drs/config - var remote config.Remote - remote, err = cfg.GetDefaultRemote() +func collectLfsFiles(ctx context.Context, cache *precommit_cache.Cache, cacheReady bool, gitRemoteName, gitRemoteLocation string, branches []string, refs []pushedRef, logger *slog.Logger) (map[string]lfs.LfsFileInfo, bool, error) { + if cacheReady { + lfsFiles, ok, err := lfsFilesFromCache(ctx, cache, refs, logger) if err != nil { - myLogger.Debug(fmt.Sprintf("Warning. Error getting default remote: %v", err)) - // Print warning to stderr and return success (exit 0) - fmt.Fprintln(os.Stderr, "Warning. Skipping DRS preparation. Error getting default remote:", err) - return nil + logger.Debug(fmt.Sprintf("pre-commit cache read failed: %v", err)) + } else if ok { + return lfsFiles, true, nil } + logger.Debug("pre-commit cache incomplete or stale; falling back to LFS discovery") + } + lfsFiles, err := lfs.GetAllLfsFiles(gitRemoteName, gitRemoteLocation, branches, logger) + if err != nil { + return nil, false, err + } + return lfsFiles, false, nil +} - // get the remote client - cli, err := cfg.GetRemoteClient(remote, myLogger) +const cacheMaxAge = 24 * time.Hour + +func lfsFilesFromCache(ctx context.Context, cache *precommit_cache.Cache, refs []pushedRef, logger *slog.Logger) (map[string]lfs.LfsFileInfo, bool, error) { + if cache == nil { + return nil, false, nil + } + paths, err := listPushedPaths(ctx, refs) + if err != nil { + return nil, false, err + } + lfsFiles := make(map[string]lfs.LfsFileInfo, len(paths)) + for _, path := range paths { + entry, ok, err := cache.ReadPathEntry(path) if err != nil { - // Print warning to stderr and return success (exit 0) - fmt.Fprintln(os.Stderr, "Warning. Skipping DRS preparation. Error getting remote client:", err) - myLogger.Debug(fmt.Sprintf("Warning. Skipping DRS preparation. 
Error getting remote client: %v", err)) - // Check for GitDrsIdxdClient + return nil, false, err } - dc, ok := cli.(*indexd.GitDrsIdxdClient) - if !ok { - return fmt.Errorf("cli is not IndexdClient: %T", cli) + if !ok || entry.LFSOID == "" { + return nil, false, nil } - myLogger.Debug(fmt.Sprintf("Current server: %s", dc.Config.ProjectId)) - - // Buffer stdin to a temp file and invoke `git lfs pre-push ` with same args and stdin. - tmp, err := os.CreateTemp("", "prepush-stdin-*") + if entry.UpdatedAt == "" || precommit_cache.StaleAfter(entry.UpdatedAt, cacheMaxAge) { + return nil, false, nil + } + stat, err := os.Stat(path) if err != nil { - myLogger.Debug(fmt.Sprintf("error creating temp file for stdin: %v", err)) - return err + logger.Debug(fmt.Sprintf("cache path stat failed for %s: %v", path, err)) + return nil, false, nil } - defer func() { - _ = tmp.Close() - _ = os.Remove(tmp.Name()) - }() - - // Copy all of stdin into the temp file. - if _, err := io.Copy(tmp, os.Stdin); err != nil { - myLogger.Debug(fmt.Sprintf("error buffering stdin: %v", err)) - return err + lfsFiles[path] = lfs.LfsFileInfo{ + Name: path, + Size: stat.Size(), + OidType: "sha256", + Oid: entry.LFSOID, + Version: "https://git-lfs.github.com/spec/v1", } + } + return lfsFiles, true, nil +} - // Rewind to start so the child process can read it. 
- if _, err := tmp.Seek(0, 0); err != nil { - myLogger.Debug(fmt.Sprintf("error seeking temp stdin: %v", err)) - return err +func listPushedPaths(ctx context.Context, refs []pushedRef) ([]string, error) { + const zeroSHA = "0000000000000000000000000000000000000000" + set := make(map[string]struct{}) + for _, ref := range refs { + if ref.LocalSHA == "" || ref.LocalSHA == zeroSHA { + continue } - - // read the temp file and get a list of all unique local branches being pushed - branches, err := readPushedBranches(tmp) - if err != nil { - myLogger.Debug(fmt.Sprintf("error reading pushed branches: %v", err)) - return err + var args []string + if ref.RemoteSHA == "" || ref.RemoteSHA == zeroSHA { + args = []string{"ls-tree", "-r", "--name-only", ref.LocalSHA} + } else { + args = []string{"diff", "--name-only", ref.RemoteSHA, ref.LocalSHA} } - - myLogger.Debug(fmt.Sprintf("Preparing DRS objects for push branches: %v", branches)) - err = drsmap.UpdateDrsObjects(cli, gitRemoteName, gitRemoteLocation, branches, myLogger) + out, err := gitOutput(ctx, args...) if err != nil { - myLogger.Debug(fmt.Sprintf("UpdateDrsObjects failed: %v", err)) - return err + return nil, err + } + for _, line := range strings.Split(strings.TrimSpace(out), "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + set[line] = struct{}{} } - myLogger.Debug("DRS objects prepared for push!") + } + paths := make([]string, 0, len(set)) + for path := range set { + paths = append(paths, path) + } + sort.Strings(paths) + return paths, nil +} - myLogger.Debug("~~~~~~~~~~~~~ COMPLETED: pre-push ~~~~~~~~~~~~~") - return nil - }, +func gitOutput(ctx context.Context, args ...string) (string, error) { + cmd := exec.CommandContext(ctx, "git", args...) 
+ cmd.Env = os.Environ() + out, err := cmd.Output() + if err != nil { + return "", fmt.Errorf("git %s: %s", strings.Join(args, " "), strings.TrimSpace(string(out))) + } + return string(out), nil } // readPushedBranches reads git push lines from the provided temp file, diff --git a/cmd/prepush/main_test.go b/cmd/prepush/main_test.go index f5d382d1..219c4c74 100644 --- a/cmd/prepush/main_test.go +++ b/cmd/prepush/main_test.go @@ -1,10 +1,19 @@ package prepush import ( + "context" + "encoding/json" + "io" + "log/slog" "os" + "os/exec" + "path/filepath" + "strings" "testing" + "time" "github.com/calypr/git-drs/internal/testutils" + "github.com/calypr/git-drs/precommit_cache" ) func TestPrepushCmd(t *testing.T) { @@ -15,6 +24,85 @@ func TestValidateArgs(t *testing.T) { testutils.RunCmdArgsTest(t) } +func TestLfsFilesFromCache(t *testing.T) { + repo := setupGitRepo(t) + filePath := filepath.Join(repo, "data", "file.bin") + if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filePath, []byte("first"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + gitCmd(t, repo, "add", "data/file.bin") + gitCmd(t, repo, "commit", "-m", "first") + oldSHA := gitOutputString(t, repo, "rev-parse", "HEAD") + + if err := os.WriteFile(filePath, []byte("second"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + gitCmd(t, repo, "add", "data/file.bin") + gitCmd(t, repo, "commit", "-m", "second") + newSHA := gitOutputString(t, repo, "rev-parse", "HEAD") + + cacheRoot := filepath.Join(repo, ".git", "drs", "pre-commit", "v1") + cache := &precommit_cache.Cache{ + GitDir: filepath.Join(repo, ".git"), + Root: cacheRoot, + PathsDir: filepath.Join(cacheRoot, "paths"), + OIDsDir: filepath.Join(cacheRoot, "oids"), + StatePath: filepath.Join(cacheRoot, "state.json"), + } + if err := os.MkdirAll(cache.PathsDir, 0o755); err != nil { + t.Fatalf("mkdir paths dir: %v", err) + } + if err := os.MkdirAll(cache.OIDsDir, 0o755); 
err != nil { + t.Fatalf("mkdir oids dir: %v", err) + } + pathEntry := precommit_cache.PathEntry{ + Path: "data/file.bin", + LFSOID: "oid-123", + UpdatedAt: time.Now().UTC().Format(time.RFC3339), + } + pathEntryFile := filepath.Join(cache.PathsDir, precommit_cache.EncodePath(pathEntry.Path)+".json") + writeJSON(t, pathEntryFile, pathEntry) + + oldwd := mustChdir(t, repo) + t.Cleanup(func() { _ = os.Chdir(oldwd) }) + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + refs := []pushedRef{{ + LocalRef: "refs/heads/main", + LocalSHA: newSHA, + RemoteRef: "refs/heads/main", + RemoteSHA: oldSHA, + }} + + lfsFiles, ok, err := lfsFilesFromCache(context.Background(), cache, refs, logger) + if err != nil { + t.Fatalf("lfsFilesFromCache: %v", err) + } + if !ok { + t.Fatalf("expected cache to be usable") + } + info, exists := lfsFiles["data/file.bin"] + if !exists { + t.Fatalf("expected lfs info for data/file.bin") + } + if info.Oid != "oid-123" { + t.Fatalf("expected oid to be oid-123, got %s", info.Oid) + } + if info.OidType != "sha256" { + t.Fatalf("expected oid type sha256, got %s", info.OidType) + } + stat, err := os.Stat(filePath) + if err != nil { + t.Fatalf("stat: %v", err) + } + if info.Size != stat.Size() { + t.Fatalf("expected size %d, got %d", stat.Size(), info.Size) + } +} + func TestReadPushedBranches(t *testing.T) { tests := []struct { name string @@ -81,3 +169,112 @@ func TestReadPushedBranches(t *testing.T) { }) } } + +func TestLfsFilesFromCacheStale(t *testing.T) { + repo := setupGitRepo(t) + filePath := filepath.Join(repo, "data", "file.bin") + if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filePath, []byte("data"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + gitCmd(t, repo, "add", "data/file.bin") + gitCmd(t, repo, "commit", "-m", "first") + sha := gitOutputString(t, repo, "rev-parse", "HEAD") + + cacheRoot := filepath.Join(repo, ".git", "drs", "pre-commit", 
"v1") + cache := &precommit_cache.Cache{ + GitDir: filepath.Join(repo, ".git"), + Root: cacheRoot, + PathsDir: filepath.Join(cacheRoot, "paths"), + OIDsDir: filepath.Join(cacheRoot, "oids"), + StatePath: filepath.Join(cacheRoot, "state.json"), + } + if err := os.MkdirAll(cache.PathsDir, 0o755); err != nil { + t.Fatalf("mkdir paths dir: %v", err) + } + + pathEntry := precommit_cache.PathEntry{ + Path: "data/file.bin", + LFSOID: "oid-123", + UpdatedAt: time.Now().Add(-48 * time.Hour).UTC().Format(time.RFC3339), + } + pathEntryFile := filepath.Join(cache.PathsDir, precommit_cache.EncodePath(pathEntry.Path)+".json") + writeJSON(t, pathEntryFile, pathEntry) + + oldwd := mustChdir(t, repo) + t.Cleanup(func() { _ = os.Chdir(oldwd) }) + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + refs := []pushedRef{{ + LocalRef: "refs/heads/main", + LocalSHA: sha, + RemoteRef: "refs/heads/main", + RemoteSHA: "0000000000000000000000000000000000000000", + }} + + _, ok, err := lfsFilesFromCache(context.Background(), cache, refs, logger) + if err != nil { + t.Fatalf("lfsFilesFromCache: %v", err) + } + if ok { + t.Fatalf("expected cache to be stale") + } +} + +func setupGitRepo(t *testing.T) string { + t.Helper() + dir := t.TempDir() + gitCmd(t, dir, "init") + gitCmd(t, dir, "config", "user.email", "test@example.com") + gitCmd(t, dir, "config", "user.name", "Test User") + return dir +} + +func gitCmd(t *testing.T, dir string, args ...string) { + t.Helper() + cmd := exec.Command("git", args...) + cmd.Dir = dir + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("git %s failed: %v (%s)", strings.Join(args, " "), err, string(out)) + } +} + +func gitOutputString(t *testing.T, dir string, args ...string) string { + t.Helper() + cmd := exec.Command("git", args...) 
+ cmd.Dir = dir + out, err := cmd.Output() + if err != nil { + t.Fatalf("git %s failed: %v (%s)", strings.Join(args, " "), err, string(out)) + } + return strings.TrimSpace(string(out)) +} + +func writeJSON(t *testing.T, path string, value any) { + t.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + data, err := json.Marshal(value) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if err := os.WriteFile(path, data, 0o644); err != nil { + t.Fatalf("write: %v", err) + } +} + +func mustChdir(t *testing.T, dir string) string { + t.Helper() + old, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd: %v", err) + } + if err := os.Chdir(dir); err != nil { + t.Fatalf("Chdir(%s): %v", dir, err) + } + return old +} diff --git a/cmd/push/main.go b/cmd/push/main.go index 8cf54c73..5f571caa 100644 --- a/cmd/push/main.go +++ b/cmd/push/main.go @@ -7,8 +7,6 @@ import ( "github.com/calypr/git-drs/drslog" "github.com/calypr/git-drs/drsmap" "github.com/spf13/cobra" - - "github.com/calypr/git-drs/client/indexd" ) var Cmd = &cobra.Command{ @@ -47,13 +45,13 @@ var Cmd = &cobra.Command{ return err } - // Check for GitDrsIdxdClient - icli, ok := drsClient.(*indexd.GitDrsIdxdClient) - if !ok { - return fmt.Errorf("remote client is not an *indexdCl.IndexDClient (got %T), cannot push", drsClient) - } + //// Check for GitDrsIdxdClient + //icli, ok := drsClient.(*indexd.GitDrsIdxdClient) + //if !ok { + // return fmt.Errorf("remote client is not an *indexdCl.IndexDClient (got %T), cannot push", drsClient) + //} - err = drsmap.PushLocalDrsObjects(drsClient, icli.GetGen3Interface(), icli.GetBucketName(), icli.GetUpsert(), myLogger) + err = drsmap.PushLocalDrsObjects(drsClient, myLogger) if err != nil { return err } diff --git a/cmd/query/main.go b/cmd/query/main.go index e450a31e..2cf4f4a2 100644 --- a/cmd/query/main.go +++ b/cmd/query/main.go @@ -5,8 +5,8 @@ import ( "fmt" "github.com/bytedance/sonic" - 
"github.com/calypr/data-client/indexd/drs" - "github.com/calypr/data-client/indexd/hash" + "github.com/calypr/data-client/drs" + "github.com/calypr/data-client/hash" "github.com/calypr/git-drs/config" "github.com/calypr/git-drs/drslog" "github.com/spf13/cobra" diff --git a/cmd/query/main_test.go b/cmd/query/main_test.go index 8d7a4696..d98d8f92 100644 --- a/cmd/query/main_test.go +++ b/cmd/query/main_test.go @@ -4,8 +4,8 @@ import ( "context" "testing" - "github.com/calypr/data-client/indexd/drs" - "github.com/calypr/data-client/indexd/hash" + "github.com/calypr/data-client/drs" + "github.com/calypr/data-client/hash" ) type fakeChecksumClient struct { diff --git a/cmd/register/main.go b/cmd/register/main.go deleted file mode 100644 index a19db23a..00000000 --- a/cmd/register/main.go +++ /dev/null @@ -1,133 +0,0 @@ -package register - -import ( - "context" - "fmt" - - "github.com/calypr/data-client/indexd/hash" - "github.com/calypr/git-drs/client/indexd" - "github.com/calypr/git-drs/config" - "github.com/calypr/git-drs/drslog" - "github.com/calypr/git-drs/drsmap" - "github.com/calypr/git-drs/lfs" - "github.com/spf13/cobra" -) - -var remote string -var Cmd = &cobra.Command{ - Use: "register", - Short: "Register all pending DRS objects with indexd", - Long: "Reads pending objects from .git/drs/lfs/objects/ and registers them with indexd (does not upload files)", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) != 0 { - cmd.SilenceUsage = false - return fmt.Errorf("error: accepts no arguments, received %d\n\nUsage: %s\n\nSee 'git drs register --help' for more details", len(args), cmd.UseLine()) - } - return nil - }, - RunE: func(cmd *cobra.Command, args []string) error { - logger, err := drslog.NewLogger("", true) - if err != nil { - return err - } - - cfg, err := config.LoadConfig() - if err != nil { - return err - } - - remoteName, err := cfg.GetRemoteOrDefault(remote) - if err != nil { - logger.Error(fmt.Sprintf("Error getting remote: %v", 
err)) - return err - } - - cli, err := cfg.GetRemoteClient(remoteName, logger) - if err != nil { - return fmt.Errorf("error creating indexd client: %v", err) - } - // Check for GitDrsIdxdClient - icli, ok := cli.(*indexd.GitDrsIdxdClient) - if !ok { - return fmt.Errorf("remote client is not an *indexd.GitDrsIdxdClient (got %T)", cli) - } - - // Get all pending objects - pendingObjects, err := lfs.GetPendingObjects(logger) - if err != nil { - return fmt.Errorf("error reading pending objects: %v", err) - } - - if len(pendingObjects) == 0 { - logger.Debug("No pending objects to register") - return nil - } - - logger.Debug(fmt.Sprintf("Found %d pending object(s) to register", len(pendingObjects))) - - registeredCount := 0 - skippedCount := 0 - errorCount := 0 - - // Register each pending object with indexd - for _, obj := range pendingObjects { - logger.Debug(fmt.Sprintf("Processing %s (OID: %s)", obj.Path, obj.OID)) - - // Read the IndexdRecord from disk - indexdObj, err := drsmap.DrsInfoFromOid(obj.OID) - if err != nil { - logger.Error(fmt.Sprintf("Error reading DRS object for %s: %v", obj.Path, err)) - errorCount++ - continue - } - - // Check if records with this hash already exist in indexd - records, err := cli.GetObjectByHash(context.Background(), &hash.Checksum{Type: "sha256", Checksum: obj.OID}) - if err != nil { - logger.Error(fmt.Sprintf("Error querying indexd for %s: %v", obj.Path, err)) - errorCount++ - continue - } - - // Check if a record with this exact DID already exists - alreadyExists := false - for _, record := range records { - if record.Id == indexdObj.Id { - alreadyExists = true - break - } - } - - if alreadyExists { - logger.Debug(fmt.Sprintf("Record for %s (DID: %s) already exists in indexd, skipping", obj.Path, indexdObj.Id)) - skippedCount++ - continue - } - - // Register the indexd record - _, err = icli.RegisterRecord(context.Background(), indexdObj) - if err != nil { - logger.Error(fmt.Sprintf("Error registering %s with indexd: %v", 
obj.Path, err)) - errorCount++ - continue - } - - logger.Debug(fmt.Sprintf("Successfully registered %s with DID %s", obj.Path, indexdObj.Id)) - registeredCount++ - } - - // Summary - logger.Debug(fmt.Sprintf("Registration complete: %d registered, %d skipped, %d errors", - registeredCount, skippedCount, errorCount)) - - if errorCount > 0 { - return fmt.Errorf("completed with %d error(s)", errorCount) - } - - return nil - }, -} - -func init() { - Cmd.Flags().StringVarP(&remote, "remote", "r", "", "target remote DRS server (default: default_remote)") -} diff --git a/cmd/register/main_test.go b/cmd/register/main_test.go deleted file mode 100644 index 51bcc658..00000000 --- a/cmd/register/main_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package register - -import ( - "testing" - - "github.com/calypr/git-drs/internal/testutils" - "github.com/stretchr/testify/assert" -) - -func TestRegisterCmdArgs(t *testing.T) { - err := Cmd.Args(Cmd, []string{}) - assert.NoError(t, err) - - err = Cmd.Args(Cmd, []string{"extra"}) - assert.Error(t, err) -} - -func TestRegisterRun_Error(t *testing.T) { - _ = testutils.SetupTestGitRepo(t) - // No config - err := Cmd.RunE(Cmd, []string{}) - assert.Error(t, err) -} diff --git a/cmd/remote/add/gen3.go b/cmd/remote/add/gen3.go index 31e11624..13bb2816 100644 --- a/cmd/remote/add/gen3.go +++ b/cmd/remote/add/gen3.go @@ -9,9 +9,9 @@ import ( "github.com/calypr/data-client/g3client" "github.com/calypr/data-client/logs" "github.com/calypr/git-drs/client/indexd" + "github.com/calypr/git-drs/common" "github.com/calypr/git-drs/config" "github.com/calypr/git-drs/drslog" - "github.com/calypr/git-drs/utils" "github.com/spf13/cobra" ) @@ -64,7 +64,7 @@ func gen3Init(remoteName, credFile, fenceToken, project, bucket string, logg *sl case fenceToken != "": accessToken = fenceToken var err error - apiEndpoint, err = utils.ParseAPIEndpointFromToken(accessToken) + apiEndpoint, err = common.ParseAPIEndpointFromToken(accessToken) if err != nil { return 
fmt.Errorf("failed to parse API endpoint from provided access token: %w", err) } @@ -78,7 +78,7 @@ func gen3Init(remoteName, credFile, fenceToken, project, bucket string, logg *sl apiKey = cred.APIKey keyID = cred.KeyID - apiEndpoint, err = utils.ParseAPIEndpointFromToken(cred.APIKey) + apiEndpoint, err = common.ParseAPIEndpointFromToken(cred.APIKey) if err != nil { return fmt.Errorf("failed to parse API endpoint from API key in credentials file: %w", err) } diff --git a/cmd/root.go b/cmd/root.go index 5bcc71b2..f210240e 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -3,18 +3,14 @@ package cmd import ( "github.com/calypr/git-drs/cmd/addref" "github.com/calypr/git-drs/cmd/addurl" - "github.com/calypr/git-drs/cmd/cache" - "github.com/calypr/git-drs/cmd/delete" + deleteCmd "github.com/calypr/git-drs/cmd/delete" "github.com/calypr/git-drs/cmd/deleteproject" - "github.com/calypr/git-drs/cmd/download" "github.com/calypr/git-drs/cmd/fetch" "github.com/calypr/git-drs/cmd/initialize" - "github.com/calypr/git-drs/cmd/list" - "github.com/calypr/git-drs/cmd/listconfig" + "github.com/calypr/git-drs/cmd/precommit" "github.com/calypr/git-drs/cmd/prepush" "github.com/calypr/git-drs/cmd/push" "github.com/calypr/git-drs/cmd/query" - "github.com/calypr/git-drs/cmd/register" "github.com/calypr/git-drs/cmd/remote" "github.com/calypr/git-drs/cmd/transfer" "github.com/calypr/git-drs/cmd/transferref" @@ -27,31 +23,29 @@ var RootCmd = &cobra.Command{ Use: "git-drs", Short: "Git DRS - Git-LFS file management for DRS servers", Long: "Git DRS provides the benefits of Git-LFS file management using DRS for seamless integration with Gen3 servers", - PersistentPreRun: func(cmd *cobra.Command, args []string) { - //pre-run code can go here - }, } func init() { - RootCmd.AddCommand(addref.Cmd) - RootCmd.AddCommand(cache.Cmd) - RootCmd.AddCommand(delete.Cmd) - RootCmd.AddCommand(deleteproject.Cmd) - RootCmd.AddCommand(register.Cmd) - RootCmd.AddCommand(download.Cmd) + // Hide internal commands + 
precommit.Cmd.Hidden = true + prepush.Cmd.Hidden = true + transfer.Cmd.Hidden = true + transferref.Cmd.Hidden = true + RootCmd.AddCommand(initialize.Cmd) - RootCmd.AddCommand(list.Cmd) - RootCmd.AddCommand(list.ListProjectCmd) - RootCmd.AddCommand(listconfig.Cmd) - RootCmd.AddCommand(prepush.Cmd) - RootCmd.AddCommand(query.Cmd) - RootCmd.AddCommand(transfer.Cmd) - RootCmd.AddCommand(transferref.Cmd) RootCmd.AddCommand(version.Cmd) - RootCmd.AddCommand(addurl.AddURLCmd) RootCmd.AddCommand(remote.Cmd) RootCmd.AddCommand(fetch.Cmd) RootCmd.AddCommand(push.Cmd) + RootCmd.AddCommand(precommit.Cmd) + RootCmd.AddCommand(prepush.Cmd) + RootCmd.AddCommand(transfer.Cmd) + RootCmd.AddCommand(transferref.Cmd) + RootCmd.AddCommand(addref.Cmd) + RootCmd.AddCommand(addurl.Cmd) + RootCmd.AddCommand(deleteCmd.Cmd) + RootCmd.AddCommand(deleteproject.Cmd) + RootCmd.AddCommand(query.Cmd) RootCmd.CompletionOptions.HiddenDefaultCmd = true RootCmd.SilenceUsage = true diff --git a/cmd/transfer/main.go b/cmd/transfer/main.go index 71706dee..7045426e 100644 --- a/cmd/transfer/main.go +++ b/cmd/transfer/main.go @@ -12,7 +12,7 @@ import ( "github.com/calypr/git-drs/common" "github.com/calypr/data-client/download" - "github.com/calypr/data-client/indexd/hash" + "github.com/calypr/data-client/hash" "github.com/calypr/git-drs/client" "github.com/calypr/git-drs/config" "github.com/calypr/git-drs/drslog" @@ -162,6 +162,18 @@ var Cmd = &cobra.Command{ errMsg := fmt.Sprintf("Error finding matching record for project %s: %v", drsClient.GetProjectId(), err) logger.ErrorContext(ctx, errMsg) lfs.WriteErrorMessage(streamEncoder, downloadMsg.Oid, 502, errMsg) + errMsg = fmt.Sprintf("Error getting signed URL for OID %s: %v", downloadMsg.Oid, err) + logger.Error(errMsg) + + drsObject, errG := drsmap.DrsInfoFromOid(downloadMsg.Oid) + if errG == nil && drsObject != nil { + manualDownloadMsg := fmt.Sprintf("%s %s", drsObject.AccessMethods[0].AccessURL.URL, drsObject.Name) + logger.Info(manualDownloadMsg) + 
lfs.WriteErrorMessage(streamEncoder, downloadMsg.Oid, 302, manualDownloadMsg) + } else { + logger.Error(fmt.Sprintf("drsClient.GetObject failed for %s: %v ", downloadMsg.Oid, errG)) + lfs.WriteErrorMessage(streamEncoder, downloadMsg.Oid, 502, errMsg) + } continue } if matchingRecord == nil { diff --git a/cmd/transferref/main.go b/cmd/transferref/main.go index c57f284d..d0726386 100644 --- a/cmd/transferref/main.go +++ b/cmd/transferref/main.go @@ -255,7 +255,7 @@ func downloadFile(remote config.Remote, sha string) (string, error) { cmd.Stdout = logFile cmd.Stderr = logFile - cmdOut, err := cmd.CombinedOutput() + cmdOut, err := cmd.Output() if err != nil { return "", fmt.Errorf("error running drs_downloader for sha %s: %s", sha, cmdOut) } diff --git a/common/common.go b/common/common.go index 8df37afc..7a23c70a 100644 --- a/common/common.go +++ b/common/common.go @@ -1,5 +1,10 @@ package common +import ( + "fmt" + "strings" +) + // AddUnique appends items from 'toAdd' to 'existing' only if they're not already present. // Returns the updated slice with unique items. 
func AddUnique[T comparable](existing []T, toAdd []T) []T { @@ -21,3 +26,11 @@ } return existing } + +func ProjectToResource(project string) (string, error) { + if !strings.Contains(project, "-") { + return "", fmt.Errorf("error: invalid project ID %s in config file, ID should look like <program>-<project>", project) + } + projectIdArr := strings.SplitN(project, "-", 2) + return "/programs/" + projectIdArr[0] + "/projects/" + projectIdArr[1], nil +} diff --git a/common/jwt.go b/common/jwt.go new file mode 100644 index 00000000..f94f69ce --- /dev/null +++ b/common/jwt.go @@ -0,0 +1,46 @@ +package common + +import ( + "fmt" + "net/url" + + "github.com/golang-jwt/jwt/v5" +) + +func ParseEmailFromToken(tokenString string) (string, error) { + claims := jwt.MapClaims{} + _, _, err := jwt.NewParser().ParseUnverified(tokenString, &claims) + if err != nil { + return "", fmt.Errorf("failed to decode token in ParseEmailFromToken: '%s': %w", tokenString, err) + } + context, ok := claims["context"].(map[string]any) + if !ok { + return "", fmt.Errorf("missing or invalid 'context' claim structure") + } + user, ok := context["user"].(map[string]any) + if !ok { + return "", fmt.Errorf("missing or invalid 'context.user' claim structure") + } + name, ok := user["name"].(string) + if !ok { + return "", fmt.Errorf("missing or invalid 'context.user.name' claim") + } + return name, nil +} + +func ParseAPIEndpointFromToken(tokenString string) (string, error) { + claims := jwt.MapClaims{} + _, _, err := jwt.NewParser().ParseUnverified(tokenString, &claims) + if err != nil { + return "", fmt.Errorf("failed to decode token in ParseAPIEndpointFromToken: '%s': %w", tokenString, err) + } + issUrl, ok := claims["iss"].(string) + if !ok { + return "", fmt.Errorf("missing or invalid 'iss' claim") + } + parsedURL, err := url.Parse(issUrl) + if err != nil { + return "", err + } + return fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host), nil +} diff 
--git a/common/jwt_test.go b/common/jwt_test.go new file mode 100644 index 00000000..556decc3 --- /dev/null +++ b/common/jwt_test.go @@ -0,0 +1,125 @@ +package common + +import ( + "testing" + + "github.com/golang-jwt/jwt/v5" +) + +func TestParseEmailFromToken(t *testing.T) { + claims := jwt.MapClaims{ + "context": map[string]any{ + "user": map[string]any{ + "name": "user@example.com", + }, + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString([]byte("secret")) + if err != nil { + t.Fatalf("sign token: %v", err) + } + + email, err := ParseEmailFromToken(tokenString) + if err != nil { + t.Fatalf("ParseEmailFromToken error: %v", err) + } + if email != "user@example.com" { + t.Fatalf("expected user@example.com, got %s", email) + } +} + +func TestParseEmailFromTokenErrors(t *testing.T) { + t.Run("invalid token", func(t *testing.T) { + if _, err := ParseEmailFromToken("not-a-token"); err == nil { + t.Fatalf("expected error for invalid token") + } + }) + + t.Run("missing context", func(t *testing.T) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{}) + tokenString, err := token.SignedString([]byte("secret")) + if err != nil { + t.Fatalf("sign token: %v", err) + } + if _, err := ParseEmailFromToken(tokenString); err == nil { + t.Fatalf("expected error for missing context") + } + }) + + t.Run("missing user", func(t *testing.T) { + claims := jwt.MapClaims{ + "context": map[string]any{}, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString([]byte("secret")) + if err != nil { + t.Fatalf("sign token: %v", err) + } + if _, err := ParseEmailFromToken(tokenString); err == nil { + t.Fatalf("expected error for missing user") + } + }) + + t.Run("missing name", func(t *testing.T) { + claims := jwt.MapClaims{ + "context": map[string]any{ + "user": map[string]any{}, + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := 
token.SignedString([]byte("secret")) + if err != nil { + t.Fatalf("sign token: %v", err) + } + if _, err := ParseEmailFromToken(tokenString); err == nil { + t.Fatalf("expected error for missing name") + } + }) +} + +func TestParseAPIEndpointFromToken(t *testing.T) { + claims := jwt.MapClaims{ + "iss": "https://api.example.com/auth", + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString([]byte("secret")) + if err != nil { + t.Fatalf("sign token: %v", err) + } + + endpoint, err := ParseAPIEndpointFromToken(tokenString) + if err != nil { + t.Fatalf("ParseAPIEndpointFromToken error: %v", err) + } + if endpoint != "https://api.example.com" { + t.Fatalf("expected https://api.example.com, got %s", endpoint) + } +} + +func TestParseAPIEndpointFromTokenErrors(t *testing.T) { + t.Run("missing iss", func(t *testing.T) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{}) + tokenString, err := token.SignedString([]byte("secret")) + if err != nil { + t.Fatalf("sign token: %v", err) + } + if _, err := ParseAPIEndpointFromToken(tokenString); err == nil { + t.Fatalf("expected error for missing iss") + } + }) + + t.Run("invalid url", func(t *testing.T) { + claims := jwt.MapClaims{ + "iss": "://missing", + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString([]byte("secret")) + if err != nil { + t.Fatalf("sign token: %v", err) + } + if _, err := ParseAPIEndpointFromToken(tokenString); err == nil { + t.Fatalf("expected error for invalid url") + } + }) +} diff --git a/config/config.go b/config/config.go index 27082457..3dfb354c 100644 --- a/config/config.go +++ b/config/config.go @@ -2,13 +2,16 @@ package config import ( "fmt" + "log" "log/slog" + "path/filepath" "strings" "github.com/calypr/data-client/g3client" "github.com/calypr/git-drs/client" anvil_client "github.com/calypr/git-drs/client/anvil" "github.com/calypr/git-drs/client/indexd" + 
"github.com/calypr/git-drs/common" "github.com/calypr/git-drs/gitrepo" "github.com/go-git/go-git/v5" ) @@ -22,6 +25,13 @@ const ( Gen3ServerType RemoteType = "gen3" AnvilServerType RemoteType = "anvil" + + newConfigSection = "lfs" + newConfigSubsectionRoot = "customtransfer.drs" + legacyConfigSection = "drs" + remoteSubsectionPrefix = "remote." + legacyDefaultRemoteKey = "drs.default-remote" + namespacedDefaultRemoteKey = "lfs.customtransfer.drs.default-remote" ) func AllRemoteTypes() []RemoteType { @@ -94,8 +104,10 @@ func (c Config) GetDefaultRemote() (Remote, error) { return "", fmt.Errorf( "no default remote configured.\n"+ "Set one with: git drs remote set \n"+ - "Available remotes: %v", + "Available remotes: %v\n"+ + "Config: %v\n", c.listRemoteNames(), + c, ) } @@ -133,6 +145,15 @@ func getRepo() (*git.Repository, error) { return gitrepo.GetRepo() } +func (c Config) ConfigPath() (string, error) { + return getConfigPath() +} + +// updates and git adds a Git DRS config file +// this should handle three cases: +// 1. create a new config file if it does not exist / is empty +// 2. return an error if the config file is invalid +// 3. update the existing config file, making sure to combine the new serversMap with the existing one // UpdateRemote updates and saves configuration using go-git func UpdateRemote(name Remote, remote RemoteSelect) (*Config, error) { repo, err := getRepo() @@ -145,22 +166,24 @@ func UpdateRemote(name Remote, remote RemoteSelect) (*Config, error) { return nil, err } - // Update drs.remote. subsection - subsection := fmt.Sprintf("remote.%s", name) + // Update lfs.customtransfer.drs.remote. 
subsection + remoteSubsectionName := fmt.Sprintf("%s.%s%s", newConfigSubsectionRoot, remoteSubsectionPrefix, name) + remoteSubsection := conf.Raw.Section(newConfigSection).Subsection(remoteSubsectionName) if remote.Gen3 != nil { - conf.Raw.Section("drs").Subsection(subsection).SetOption("type", "gen3") - conf.Raw.Section("drs").Subsection(subsection).SetOption("endpoint", remote.Gen3.Endpoint) - conf.Raw.Section("drs").Subsection(subsection).SetOption("project", remote.Gen3.ProjectID) - conf.Raw.Section("drs").Subsection(subsection).SetOption("bucket", remote.Gen3.Bucket) + remoteSubsection.SetOption("type", "gen3") + remoteSubsection.SetOption("endpoint", remote.Gen3.Endpoint) + remoteSubsection.SetOption("project", remote.Gen3.ProjectID) + remoteSubsection.SetOption("bucket", remote.Gen3.Bucket) } else if remote.Anvil != nil { - conf.Raw.Section("drs").Subsection(subsection).SetOption("type", "anvil") + remoteSubsection.SetOption("type", "anvil") } // Set default remote if not set - defaultRemote := conf.Raw.Section("drs").Option("default-remote") + configRoot := conf.Raw.Section(newConfigSection).Subsection(newConfigSubsectionRoot) + defaultRemote := configRoot.Option("default-remote") if defaultRemote == "" { - conf.Raw.Section("drs").SetOption("default-remote", string(name)) + configRoot.SetOption("default-remote", string(name)) } // Save config @@ -171,6 +194,27 @@ func UpdateRemote(name Remote, remote RemoteSelect) (*Config, error) { return LoadConfig() } +func parseAndAddRemote(cfg *Config, subsectionName string, remoteType string, endpoint string, project string, bucket string) { + if !strings.HasPrefix(subsectionName, remoteSubsectionPrefix) { + return + } + + remoteName := Remote(strings.TrimPrefix(subsectionName, remoteSubsectionPrefix)) + rs := RemoteSelect{} + + if remoteType == "gen3" || remoteType == "" { + rs.Gen3 = &indexd.Gen3Remote{ + Endpoint: endpoint, + ProjectID: project, + Bucket: bucket, + } + } else if remoteType == "anvil" { + rs.Anvil = 
&anvil_client.AnvilRemote{} + } + + cfg.Remotes[remoteName] = rs +} + // LoadConfig loads configuration using go-git func LoadConfig() (*Config, error) { repo, err := getRepo() @@ -187,30 +231,51 @@ func LoadConfig() (*Config, error) { Remotes: make(map[Remote]RemoteSelect), } - drsSection := conf.Raw.Section("drs") - cfg.DefaultRemote = Remote(drsSection.Option("default-remote")) + lfsSection := conf.Raw.Section(newConfigSection) + newRoot := lfsSection.Subsection(newConfigSubsectionRoot) + legacyRoot := conf.Raw.Section(legacyConfigSection) - for _, subsection := range drsSection.Subsections { - // Expect subsection name "remote." - parts := strings.SplitN(subsection.Name, ".", 2) - if len(parts) != 2 || parts[0] != "remote" { - continue + cfg.DefaultRemote = Remote(newRoot.Option("default-remote")) + if cfg.DefaultRemote == "" { + legacyDefault := legacyRoot.Option("default-remote") + if legacyDefault != "" { + log.Printf("Warning: git-drs config key '%s' is deprecated; use '%s'", legacyDefaultRemoteKey, namespacedDefaultRemoteKey) + cfg.DefaultRemote = Remote(legacyDefault) } - remoteName := Remote(parts[1]) - rs := RemoteSelect{} - - remoteType := subsection.Option("type") - if remoteType == "gen3" || remoteType == "" { // Default to gen3 for compatibility/inference - rs.Gen3 = &indexd.Gen3Remote{ - Endpoint: subsection.Option("endpoint"), - ProjectID: subsection.Option("project"), - Bucket: subsection.Option("bucket"), - } - } else if remoteType == "anvil" { - rs.Anvil = &anvil_client.AnvilRemote{} + } + + for _, subsection := range lfsSection.Subsections { + if !strings.HasPrefix(subsection.Name, newConfigSubsectionRoot+".") { + continue } + relativeName := strings.TrimPrefix(subsection.Name, newConfigSubsectionRoot+".") + parseAndAddRemote( + cfg, + relativeName, + subsection.Option("type"), + subsection.Option("endpoint"), + subsection.Option("project"), + subsection.Option("bucket"), + ) + } - cfg.Remotes[remoteName] = rs + for _, subsection := range 
legacyRoot.Subsections { + if !strings.HasPrefix(subsection.Name, remoteSubsectionPrefix) { + continue + } + remoteName := Remote(strings.TrimPrefix(subsection.Name, remoteSubsectionPrefix)) + if _, exists := cfg.Remotes[remoteName]; exists { + continue + } + log.Printf("Warning: git-drs config key prefix 'drs.%s' is deprecated; use 'lfs.customtransfer.drs.%s'", subsection.Name, subsection.Name) + parseAndAddRemote( + cfg, + subsection.Name, + subsection.Option("type"), + subsection.Option("endpoint"), + subsection.Option("project"), + subsection.Option("bucket"), + ) } return cfg, nil @@ -249,7 +314,7 @@ func SaveConfig(cfg *Config) error { } if cfg.DefaultRemote != "" { - conf.Raw.Section("drs").SetOption("default-remote", string(cfg.DefaultRemote)) + conf.Raw.Section(newConfigSection).Subsection(newConfigSubsectionRoot).SetOption("default-remote", string(cfg.DefaultRemote)) } return repo.Storer.SetConfig(conf) @@ -257,3 +322,12 @@ func SaveConfig(cfg *Config) error { // GetGitConfigInt reads an integer value from git config // getGitConfigValue retrieves a value from git config by key +func getConfigPath() (string, error) { + topLevel, err := gitrepo.GitTopLevel() + if err != nil { + return "", err + } + + configPath := filepath.Join(topLevel, common.DRS_DIR, common.CONFIG_YAML) + return configPath, nil +} diff --git a/config/config_test.go b/config/config_test.go index 84c73e66..4e8d7c2d 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -199,3 +199,70 @@ func TestConfig_MultipleRemotes(t *testing.T) { t.Errorf("Expected 3 remotes, got %d", len(cfg.Remotes)) } } + +func TestLoadConfig_LegacyKeysRemainSupported(t *testing.T) { + tmpDir := setupTestRepo(t) + + commands := [][]string{ + {"config", "drs.default-remote", "legacy"}, + {"config", "drs.remote.legacy.type", "gen3"}, + {"config", "drs.remote.legacy.endpoint", "https://legacy.example"}, + {"config", "drs.remote.legacy.project", "legacy-proj"}, + {"config", "drs.remote.legacy.bucket", 
"legacy-bucket"}, + } + for _, args := range commands { + cmd := exec.Command("git", args...) + cmd.Dir = tmpDir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git %v failed: %v: %s", args, err, string(out)) + } + } + + cfg, err := LoadConfig() + if err != nil { + t.Fatalf("LoadConfig error: %v", err) + } + if cfg.DefaultRemote != Remote("legacy") { + t.Fatalf("expected legacy default remote, got %s", cfg.DefaultRemote) + } + legacy := cfg.Remotes[Remote("legacy")] + if legacy.Gen3 == nil || legacy.Gen3.Endpoint != "https://legacy.example" { + t.Fatalf("expected legacy gen3 remote loaded, got %#v", legacy) + } +} + +func TestLoadConfig_NamespacedKeysTakePrecedence(t *testing.T) { + tmpDir := setupTestRepo(t) + + commands := [][]string{ + {"config", "drs.default-remote", "legacy"}, + {"config", "drs.remote.legacy.type", "gen3"}, + {"config", "drs.remote.legacy.endpoint", "https://legacy.example"}, + {"config", "drs.remote.legacy.project", "legacy-proj"}, + {"config", "drs.remote.legacy.bucket", "legacy-bucket"}, + {"config", "lfs.customtransfer.drs.default-remote", "new"}, + {"config", "lfs.customtransfer.drs.remote.new.type", "gen3"}, + {"config", "lfs.customtransfer.drs.remote.new.endpoint", "https://new.example"}, + {"config", "lfs.customtransfer.drs.remote.new.project", "new-proj"}, + {"config", "lfs.customtransfer.drs.remote.new.bucket", "new-bucket"}, + } + for _, args := range commands { + cmd := exec.Command("git", args...) 
+ cmd.Dir = tmpDir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git %v failed: %v: %s", args, err, string(out)) + } + } + + cfg, err := LoadConfig() + if err != nil { + t.Fatalf("LoadConfig error: %v", err) + } + if cfg.DefaultRemote != Remote("new") { + t.Fatalf("expected namespaced default remote, got %s", cfg.DefaultRemote) + } + newRemote := cfg.Remotes[Remote("new")] + if newRemote.Gen3 == nil || newRemote.Gen3.Endpoint != "https://new.example" { + t.Fatalf("expected namespaced gen3 remote loaded, got %#v", newRemote) + } +} diff --git a/coverage/combined.html b/coverage/combined.html index 5a536b43..a529cf7a 100644 --- a/coverage/combined.html +++ b/coverage/combined.html @@ -61,7 +61,7 @@ - + @@ -69,91 +69,105 @@ - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + + + + + + + + + + + + + + + @@ -188,8 +202,8 @@ "time" "github.com/bytedance/sonic" - drs "github.com/calypr/data-client/indexd/drs" - hash "github.com/calypr/data-client/indexd/hash" + drs "github.com/calypr/data-client/drs" + hash "github.com/calypr/data-client/hash" "golang.org/x/oauth2/google" ) @@ -337,17 +351,17 @@ awsConfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/calypr/data-client/drs" "github.com/calypr/data-client/fence" + "github.com/calypr/data-client/hash" "github.com/calypr/data-client/indexd" - "github.com/calypr/data-client/indexd/drs" - "github.com/calypr/data-client/indexd/hash" + "github.com/calypr/data-client/s3utils" + "github.com/calypr/git-drs/cloud" "github.com/calypr/git-drs/common" "github.com/calypr/git-drs/drslog" "github.com/calypr/git-drs/drsmap" "github.com/calypr/git-drs/lfs" "github.com/calypr/git-drs/messages" - "github.com/calypr/git-drs/s3_utils" - "github.com/calypr/git-drs/utils" ) // getBucketDetails fetches bucket details from 
Gen3 using data-client. @@ -357,7 +371,7 @@ // FetchS3MetadataWithBucketDetails fetches S3 metadata given bucket details. func FetchS3MetadataWithBucketDetails(ctx context.Context, s3URL, awsAccessKey, awsSecretKey, region, endpoint string, bucketDetails *fence.S3Bucket, s3Client *s3.Client, logger *slog.Logger) (int64, string, error) { - bucket, key, err := utils.ParseS3URL(s3URL) + bucket, key, err := cloud.ParseS3URL(s3URL) if err != nil { return 0, "", fmt.Errorf("failed to parse S3 URL: %w", err) } @@ -426,7 +440,7 @@ } func (inc *GitDrsIdxdClient) fetchS3Metadata(ctx context.Context, s3URL, awsAccessKey, awsSecretKey, region, endpoint string, s3Client *s3.Client, httpClient *http.Client, logger *slog.Logger) (int64, string, error) { - bucket, _, err := utils.ParseS3URL(s3URL) + bucket, _, err := cloud.ParseS3URL(s3URL) if err != nil { return 0, "", fmt.Errorf("failed to parse S3 URL: %w", err) } @@ -470,7 +484,7 @@ // If no record exists, create one logger.Debug("creating new record") - _, relPath, _ := utils.ParseS3URL(url) + _, relPath, _ := cloud.ParseS3URL(url) drsObj, err := drs.BuildDrsObj(relPath, sha256, fileSize, uuid, inc.Config.BucketName, projectId) if err != nil { @@ -482,11 +496,11 @@ return inc.RegisterRecord(ctx, drsObj) } -func (inc *GitDrsIdxdClient) AddURL(s3URL, sha256, awsAccessKey, awsSecretKey, regionFlag, endpointFlag string, opts ...s3_utils.AddURLOption) (s3_utils.S3Meta, error) { +func (inc *GitDrsIdxdClient) AddURL(s3URL, sha256, awsAccessKey, awsSecretKey, regionFlag, endpointFlag string, opts ...cloud.AddURLOption) (s3utils.S3Meta, error) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - cfg := &s3_utils.AddURLConfig{} + cfg := &cloud.AddURLConfig{} for _, opt := range opts { opt(cfg) } @@ -495,27 +509,27 @@ inc.Logger = drslog.NewNoOpLogger() } - if err := s3_utils.ValidateInputs(s3URL, sha256); err != nil { - return s3_utils.S3Meta{}, err + if err := s3utils.ValidateInputs(s3URL, 
sha256); err != nil { + return s3utils.S3Meta{}, err } - _, relPath, err := utils.ParseS3URL(s3URL) + _, relPath, err := cloud.ParseS3URL(s3URL) if err != nil { - return s3_utils.S3Meta{}, fmt.Errorf("failed to parse S3 URL: %w", err) + return s3utils.S3Meta{}, fmt.Errorf("failed to parse S3 URL: %w", err) } - isLFS, err := lfs.IsLFSTracked(".gitattributes", relPath) + isLFS, err := lfs.IsLFSTracked(relPath) if err != nil { - return s3_utils.S3Meta{}, fmt.Errorf("unable to determine if file is tracked by LFS: %w", err) + return s3utils.S3Meta{}, fmt.Errorf("unable to determine if file is tracked by LFS: %w", err) } if !isLFS { - return s3_utils.S3Meta{}, fmt.Errorf("file is not tracked by LFS") + return s3utils.S3Meta{}, fmt.Errorf("file is not tracked by LFS") } inc.Logger.Debug("Fetching S3 metadata...") fileSize, modifiedDate, err := inc.fetchS3Metadata(ctx, s3URL, awsAccessKey, awsSecretKey, regionFlag, endpointFlag, cfg.S3Client, cfg.HttpClient, inc.Logger) if err != nil { - return s3_utils.S3Meta{}, fmt.Errorf("failed to fetch S3 metadata: %w", err) + return s3utils.S3Meta{}, fmt.Errorf("failed to fetch S3 metadata: %w", err) } inc.Logger.Debug(fmt.Sprintf("Fetched S3 metadata successfully: %d bytes, modified: %s", fileSize, modifiedDate)) @@ -523,20 +537,20 @@ inc.Logger.Debug("Processing indexd record...") drsObj, err := inc.upsertIndexdRecord(ctx, s3URL, sha256, fileSize, inc.Logger) if err != nil { - return s3_utils.S3Meta{}, fmt.Errorf("failed to create indexd record: %w", err) + return s3utils.S3Meta{}, fmt.Errorf("failed to create indexd record: %w", err) } drsObjPath, err := drsmap.GetObjectPath(common.DRS_OBJS_PATH, drsObj.Checksums.SHA256) if err != nil { - return s3_utils.S3Meta{}, err + return s3utils.S3Meta{}, err } if err := drsmap.WriteDrsObj(drsObj, sha256, drsObjPath); err != nil { - return s3_utils.S3Meta{}, err + return s3utils.S3Meta{}, err } inc.Logger.Debug("Indexd updated") - return s3_utils.S3Meta{ + return s3utils.S3Meta{ Size: 
fileSize, LastModified: modifiedDate, }, nil @@ -553,9 +567,9 @@ "github.com/calypr/data-client/common" "github.com/calypr/data-client/conf" + "github.com/calypr/data-client/drs" "github.com/calypr/data-client/g3client" - "github.com/calypr/data-client/indexd/drs" - "github.com/calypr/data-client/indexd/hash" + "github.com/calypr/data-client/hash" "github.com/calypr/data-client/logs" "github.com/calypr/git-drs/client" "github.com/calypr/git-drs/drsmap" @@ -577,18 +591,18 @@ Config *Config } -func NewGitDrsIdxdClient(profileConfig conf.Credential, remote Gen3Remote, logger *slog.Logger, opts ...g3client.Option) (client.DRSClient, error) { +func NewGitDrsIdxdClient(profileConfig conf.Credential, remote Gen3Remote, logger *slog.Logger, opts ...g3client.Option) (client.DRSClient, error) { baseUrl, err := url.Parse(profileConfig.APIEndpoint) if err != nil { return nil, err } - projectId := remote.GetProjectId() + projectId := remote.GetProjectId() if projectId == "" { return nil, fmt.Errorf("no gen3 project specified") } - bucketName := remote.GetBucketName() + bucketName := remote.GetBucketName() // Initialize data-client Gen3Interface with slog-adapted logger if needed, // or assume we use the one passed in if we update data-client to take slog. @@ -608,18 +622,18 @@ if enableDataClientLogs { logOpts = append(logOpts, logs.WithMessageFile()) - } else { + } else { logOpts = append(logOpts, logs.WithNoMessageFile()) } - dLogger, closer := logs.New(profileConfig.Profile, logOpts...) + dLogger, closer := logs.New(profileConfig.Profile, logOpts...) _ = closer // If no options provided, use defaults for GitDrsIdxdClient - if len(opts) == 0 { + if len(opts) == 0 { opts = append(opts, g3client.WithClients(g3client.IndexdClient, g3client.FenceClient, g3client.SowerClient)) } - g3 := g3client.NewGen3InterfaceFromCredential(&profileConfig, dLogger, opts...) + g3 := g3client.NewGen3InterfaceFromCredential(&profileConfig, dLogger, opts...) 
upsert := gitrepo.GetGitConfigBool("lfs.customtransfer.drs.upsert", false) multiPartThresholdInt := gitrepo.GetGitConfigInt("lfs.customtransfer.drs.multipart-threshold", 500) @@ -640,7 +654,7 @@ }, nil } -func (cl *GitDrsIdxdClient) GetProjectId() string { +func (cl *GitDrsIdxdClient) GetProjectId() string { return cl.Config.ProjectId } @@ -750,7 +764,7 @@ return cl.G3.Indexd().GetProjectSample(ctx, projectId, limit) } -func (c *GitDrsIdxdClient) RegisterRecord(ctx context.Context, record *drs.DRSObject) (*drs.DRSObject, error) { +func (c *GitDrsIdxdClient) RegisterRecord(ctx context.Context, record *drs.DRSObject) (*drs.DRSObject, error) { return c.G3.Indexd().RegisterRecord(ctx, record) } @@ -758,11 +772,11 @@ return c.G3.Indexd().UpdateRecord(ctx, updateInfo, did) } -func (c *GitDrsIdxdClient) BuildDrsObj(fileName string, checksum string, size int64, drsId string) (*drs.DRSObject, error) { +func (c *GitDrsIdxdClient) BuildDrsObj(fileName string, checksum string, size int64, drsId string) (*drs.DRSObject, error) { return drs.BuildDrsObj(fileName, checksum, size, drsId, c.Config.BucketName, c.Config.ProjectId) } -func (cl *GitDrsIdxdClient) GetGen3Interface() g3client.Gen3Interface { +func (cl *GitDrsIdxdClient) GetGen3Interface() g3client.Gen3Interface { return cl.G3 } @@ -794,7 +808,7 @@ Bucket string `yaml:"bucket"` } -func (s Gen3Remote) GetProjectId() string { +func (s Gen3Remote) GetProjectId() string { return s.ProjectID } @@ -802,23 +816,23 @@ return s.Endpoint } -func (s Gen3Remote) GetBucketName() string { +func (s Gen3Remote) GetBucketName() string { return s.Bucket } -func (s Gen3Remote) GetClient(remoteName string, logger *slog.Logger, opts ...g3client.Option) (client.DRSClient, error) { +func (s Gen3Remote) GetClient(remoteName string, logger *slog.Logger, opts ...g3client.Option) (client.DRSClient, error) { manager := conf.NewConfigure(logger) cred, err := manager.Load(remoteName) if err != nil { return nil, err } - gen3Logger := 
logs.NewGen3Logger(logger, "", remoteName) + gen3Logger := logs.NewGen3Logger(logger, "", remoteName) if err := g3client.EnsureValidCredential(context.Background(), cred, manager, gen3Logger, nil); err != nil { return nil, err } - return NewGitDrsIdxdClient(*cred, s, logger, opts...) + return NewGitDrsIdxdClient(*cred, s, logger, opts...) } @@ -831,7 +845,7 @@ "strings" "github.com/calypr/data-client/common" - "github.com/calypr/data-client/indexd/drs" + "github.com/calypr/data-client/drs" "github.com/calypr/data-client/upload" "github.com/calypr/git-drs/drsmap" ) @@ -839,7 +853,7 @@ // RegisterFile implements DRSClient.RegisterFile // It registers (or reuses) an indexd record for the oid, uploads the object if it // is not already available in the bucket, and returns the resulting DRS object. -func (cl *GitDrsIdxdClient) RegisterFile(ctx context.Context, oid string, path string) (*drs.DRSObject, error) { +func (cl *GitDrsIdxdClient) RegisterFile(ctx context.Context, oid string, path string) (*drs.DRSObject, error) { cl.Logger.DebugContext(ctx, fmt.Sprintf("register file started for oid: %s", oid)) // load the DRS object from oid created by prepush @@ -848,20 +862,20 @@ return nil, fmt.Errorf("error getting drs object for oid %s: %v", oid, err) } - cl.Logger.InfoContext(ctx, fmt.Sprintf("registering record for oid %s in indexd (did: %s)", oid, drsObject.Id)) + cl.Logger.InfoContext(ctx, fmt.Sprintf("registering record for oid %s in indexd (did: %s)", oid, drsObject.Id)) _, err = cl.RegisterRecord(ctx, drsObject) - if err != nil { + if err != nil { // handle "already exists" error ie upsert behavior - if strings.Contains(err.Error(), "already exists") { + if strings.Contains(err.Error(), "already exists") { if !cl.Config.Upsert { cl.Logger.DebugContext(ctx, fmt.Sprintf("indexd record already exists, proceeding for oid %s: did: %s err: %v", oid, drsObject.Id, err)) - } else { + } else { cl.Logger.DebugContext(ctx, fmt.Sprintf("indexd record already exists, deleting 
and re-adding for oid %s: did: %s", oid, drsObject.Id)) err = cl.DeleteRecord(ctx, oid) if err != nil { return nil, fmt.Errorf("error deleting existing indexd record oid %s: did: %s err: %v", oid, drsObject.Id, err) } - _, err = cl.RegisterRecord(ctx, drsObject) + _, err = cl.RegisterRecord(ctx, drsObject) if err != nil { return nil, fmt.Errorf("error re-saving indexd record after deletion: oid %s: did: %s err: %v", oid, drsObject.Id, err) } @@ -870,7 +884,7 @@ return nil, fmt.Errorf("error saving oid %s indexd record: %v", oid, err) } } - cl.Logger.InfoContext(ctx, fmt.Sprintf("indexd record registration complete for oid %s", oid)) + cl.Logger.InfoContext(ctx, fmt.Sprintf("indexd record registration complete for oid %s", oid)) // Now attempt to upload the file if not already available cl.Logger.InfoContext(ctx, fmt.Sprintf("checking if oid %s is already downloadable", oid)) @@ -878,11 +892,11 @@ if err != nil { return nil, fmt.Errorf("error checking if file is downloadable: oid %s %v", oid, err) } - if downloadable { + if downloadable { cl.Logger.DebugContext(ctx, fmt.Sprintf("file %s is already available for download, skipping upload", oid)) return drsObject, nil } - cl.Logger.InfoContext(ctx, fmt.Sprintf("file %s is not downloadable, proceeding to upload", oid)) + cl.Logger.InfoContext(ctx, fmt.Sprintf("file %s is not downloadable, proceeding to upload", oid)) // Proceed to upload the file // Reuse the Gen3 interface @@ -893,7 +907,7 @@ if err != nil { return nil, fmt.Errorf("error opening file %s: %v", filePath, err) } - defer func(file *os.File) { + defer func(file *os.File) { err := file.Close() if err != nil { cl.Logger.DebugContext(ctx, fmt.Sprintf("warning: error closing file %s: %v", filePath, err)) @@ -901,12 +915,12 @@ }(file) // Use multipart threshold from config or default to 5GB - multiPartThreshold := int64(5 * 1024 * 1024 * 1024) // 5GB default - if cl.Config.MultiPartThreshold > 0 { + multiPartThreshold := int64(5 * 1024 * 1024 * 1024) // 5GB 
default + if cl.Config.MultiPartThreshold > 0 { multiPartThreshold = cl.Config.MultiPartThreshold } - if drsObject.Size < multiPartThreshold { + if drsObject.Size < multiPartThreshold { cl.Logger.DebugContext(ctx, fmt.Sprintf("UploadSingle size: %d path: %s", drsObject.Size, filePath)) req := common.FileUploadRequestObject{ SourcePath: filePath, @@ -936,23 +950,23 @@ return nil, fmt.Errorf("MultipartUpload error: %s", err) } } - return drsObject, nil + return drsObject, nil } // isFileDownloadable checks if a file is already available for download -func (cl *GitDrsIdxdClient) isFileDownloadable(ctx context.Context, drsObject *drs.DRSObject) (bool, error) { +func (cl *GitDrsIdxdClient) isFileDownloadable(ctx context.Context, drsObject *drs.DRSObject) (bool, error) { // Try to get a download URL - if successful, file is downloadable if len(drsObject.AccessMethods) == 0 { return false, nil } - accessType := drsObject.AccessMethods[0].Type + accessType := drsObject.AccessMethods[0].Type res, err := cl.G3.Indexd().GetDownloadURL(ctx, drsObject.Id, accessType) if err != nil { // If we can't get a download URL, assume file is not downloadable return false, nil } // Check if the URL is accessible - err = common.CanDownloadFile(res.URL) + err = common.CanDownloadFile(res.URL) return err == nil, nil } @@ -994,1283 +1008,2997 @@ } -