diff --git a/.github/workflows/pr-checks.yaml b/.github/workflows/pr-checks.yaml index a9882cbb..94190b0c 100644 --- a/.github/workflows/pr-checks.yaml +++ b/.github/workflows/pr-checks.yaml @@ -71,4 +71,4 @@ jobs: go-version-file: go.mod - name: Run tests - run: go test -v -race $(go list ./... | grep -v 'tests/integration/calypr' | grep -v 'client/indexd/tests') + run: go test -v -race $(go list ./... | grep -v '/cmd/addurl$') diff --git a/client/indexd/add_url.go b/client/indexd/add_url.go index 5828fe96..6debf002 100644 --- a/client/indexd/add_url.go +++ b/client/indexd/add_url.go @@ -15,19 +15,19 @@ import ( awsConfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/calypr/git-drs/cloud" "github.com/calypr/git-drs/drs" "github.com/calypr/git-drs/drs/hash" "github.com/calypr/git-drs/drslog" "github.com/calypr/git-drs/drsmap" "github.com/calypr/git-drs/messages" "github.com/calypr/git-drs/projectdir" - "github.com/calypr/git-drs/s3_utils" "github.com/calypr/git-drs/utils" ) // getBucketDetails fetches bucket details from Gen3, loading config and auth. // This is the production version that includes all config/auth dependencies. -func (inc *IndexDClient) getBucketDetails(ctx context.Context, bucket string, httpClient *http.Client) (*s3_utils.S3Bucket, error) { +func (inc *IndexDClient) getBucketDetails(ctx context.Context, bucket string, httpClient *http.Client) (*cloud.S3Bucket, error) { // get all buckets baseURL := *inc.Base // Create a copy to avoid mutating inc.Base baseURL.Path = filepath.Join(baseURL.Path, "user/data/buckets") @@ -37,7 +37,7 @@ func (inc *IndexDClient) getBucketDetails(ctx context.Context, bucket string, ht // FetchS3MetadataWithBucketDetails fetches S3 metadata given bucket details. // This is the core testable logic, separated for easier unit testing. -func FetchS3MetadataWithBucketDetails(ctx context.Context, s3URL, awsAccessKey, awsSecretKey, region, endpoint string, bucketDetails *s3_utils.S3Bucket, s3Client *s3.Client, logger *slog.Logger) (int64, string, error) { +func FetchS3MetadataWithBucketDetails(ctx context.Context, s3URL, awsAccessKey, awsSecretKey, region, endpoint string, bucketDetails *cloud.S3Bucket, s3Client *s3.Client, logger *slog.Logger) (int64, string, error) { // Parse S3 URL bucket, key, err := utils.ParseS3URL(s3URL) @@ -148,8 +148,8 @@ func FetchS3MetadataWithBucketDetails(ctx context.Context, s3URL, awsAccessKey, errorMsg.WriteString(fmt.Sprintf(" %d. 
%s\n", i+1, field)) } errorMsg.WriteString("\nPlease provide these values via:\n") - errorMsg.WriteString(" - Command-line flags (--" + s3_utils.AWS_KEY_FLAG_NAME + ", --" + s3_utils.AWS_SECRET_FLAG_NAME + ", --" + s3_utils.AWS_REGION_FLAG_NAME + ", --" + s3_utils.AWS_ENDPOINT_URL_FLAG_NAME + ")\n") - errorMsg.WriteString(" - Environment variables (" + s3_utils.AWS_KEY_ENV_VAR + ", " + s3_utils.AWS_SECRET_ENV_VAR + ", " + s3_utils.AWS_REGION_ENV_VAR + ", " + s3_utils.AWS_ENDPOINT_URL_ENV_VAR + ")\n") + errorMsg.WriteString(" - Command-line flags (--" + cloud.AWS_KEY_FLAG_NAME + ", --" + cloud.AWS_SECRET_FLAG_NAME + ", --" + cloud.AWS_REGION_FLAG_NAME + ", --" + cloud.AWS_ENDPOINT_URL_FLAG_NAME + ")\n") + errorMsg.WriteString(" - Environment variables (" + cloud.AWS_KEY_ENV_VAR + ", " + cloud.AWS_SECRET_ENV_VAR + ", " + cloud.AWS_REGION_ENV_VAR + ", " + cloud.AWS_ENDPOINT_URL_ENV_VAR + ")\n") errorMsg.WriteString(" - AWS credentials file (~/.aws/credentials)\n") errorMsg.WriteString(" - Gen3 bucket registration (if bucket can be registered in Gen3)\n") errorMsg.WriteString("\n") @@ -199,7 +199,7 @@ func (inc *IndexDClient) fetchS3Metadata(ctx context.Context, s3URL, awsAccessKe } if bucketDetails == nil { logger.Debug("WARNING: no matching bucket found in CALYPR") - bucketDetails = &s3_utils.S3Bucket{} + bucketDetails = &cloud.S3Bucket{} } return FetchS3MetadataWithBucketDetails(ctx, s3URL, awsAccessKey, awsSecretKey, region, endpoint, bucketDetails, s3Client, logger) @@ -279,13 +279,13 @@ func (inc *IndexDClient) upsertIndexdRecord(url string, sha256 string, fileSize } // AddURL adds a file to the Git DRS repo using an S3 URL -func (inc *IndexDClient) AddURL(s3URL, sha256, awsAccessKey, awsSecretKey, regionFlag, endpointFlag string, opts ...s3_utils.AddURLOption) (s3_utils.S3Meta, error) { +func (inc *IndexDClient) AddURL(s3URL, sha256, awsAccessKey, awsSecretKey, regionFlag, endpointFlag string, opts ...cloud.AddURLOption) (cloud.S3Meta, error) { // Create context with 10-second timeout ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() // Apply options - cfg := &s3_utils.AddURLConfig{} + cfg := &cloud.AddURLConfig{} for _, opt := range opts { opt(cfg) } @@ -296,23 +296,23 @@ func (inc *IndexDClient) AddURL(s3URL, sha256, awsAccessKey, awsSecretKey, regio } // Validate inputs - if err := s3_utils.ValidateInputs(s3URL, sha256); err != nil { - return s3_utils.S3Meta{}, err + if err := cloud.ValidateInputs(s3URL, sha256); err != nil { + return cloud.S3Meta{}, err } // check that lfs is tracking the file _, relPath, err := utils.ParseS3URL(s3URL) if err != nil { - return s3_utils.S3Meta{}, fmt.Errorf("failed to parse S3 URL: %w", err) + return cloud.S3Meta{}, fmt.Errorf("failed to parse S3 URL: %w", err) } // confirm file is tracked isLFS, err := utils.IsLFSTracked(".gitattributes", relPath) if err != nil { - return s3_utils.S3Meta{}, fmt.Errorf("unable to determine if file is tracked by LFS: %w", err) + return cloud.S3Meta{}, fmt.Errorf("unable to determine if file is tracked by LFS: %w", err) } if !isLFS { - return s3_utils.S3Meta{}, fmt.Errorf("file is not tracked by LFS. Please run `git lfs track %s && git add .gitattributes` before proceeding", relPath) + return cloud.S3Meta{}, fmt.Errorf("file is not tracked by LFS. 
Please run `git lfs track %s && git add .gitattributes` before proceeding", relPath) } // Fetch S3 metadata (size, modified date) @@ -321,9 +321,9 @@ func (inc *IndexDClient) AddURL(s3URL, sha256, awsAccessKey, awsSecretKey, regio if err != nil { // if err contains 403, probably misconfigured credentials if strings.Contains(err.Error(), "403") { - return s3_utils.S3Meta{}, fmt.Errorf("failed to fetch S3 metadata: %w. Double check your configured AWS credentials and endpoint url", err) + return cloud.S3Meta{}, fmt.Errorf("failed to fetch S3 metadata: %w. Double check your configured AWS credentials and endpoint url", err) } - return s3_utils.S3Meta{}, fmt.Errorf("failed to fetch S3 metadata: %w", err) + return cloud.S3Meta{}, fmt.Errorf("failed to fetch S3 metadata: %w", err) } // logging @@ -335,21 +335,21 @@ func (inc *IndexDClient) AddURL(s3URL, sha256, awsAccessKey, awsSecretKey, regio inc.Logger.Debug("Processing indexd record...") drsObj, err := inc.upsertIndexdRecord(s3URL, sha256, fileSize, inc.Logger) if err != nil { - return s3_utils.S3Meta{}, fmt.Errorf("failed to create indexd record: %w", err) + return cloud.S3Meta{}, fmt.Errorf("failed to create indexd record: %w", err) } // write to file so push has that file available drsObjPath, err := drsmap.GetObjectPath(projectdir.DRS_OBJS_PATH, drsObj.Checksums.SHA256) if err != nil { - return s3_utils.S3Meta{}, fmt.Errorf("failed to get object path: %w", err) + return cloud.S3Meta{}, fmt.Errorf("failed to get object path: %w", err) } if err := drsmap.WriteDrsObj(drsObj, sha256, drsObjPath); err != nil { - return s3_utils.S3Meta{}, fmt.Errorf("failed to write DRS object: %w", err) + return cloud.S3Meta{}, fmt.Errorf("failed to write DRS object: %w", err) } inc.Logger.Debug("Indexd updated") - return s3_utils.S3Meta{ + return cloud.S3Meta{ Size: fileSize, LastModified: modifiedDate, }, nil diff --git a/client/indexd/indexd_client.go b/client/indexd/indexd_client.go index e0b4b6d6..0693907a 100644 --- a/client/indexd/indexd_client.go +++ b/client/indexd/indexd_client.go @@ -21,12 +21,12 @@ import ( "github.com/calypr/data-client/client/logs" "github.com/calypr/data-client/client/upload" "github.com/calypr/git-drs/client" + "github.com/calypr/git-drs/cloud" "github.com/calypr/git-drs/drs" "github.com/calypr/git-drs/drs/hash" "github.com/calypr/git-drs/drslog" "github.com/calypr/git-drs/drsmap" "github.com/calypr/git-drs/projectdir" - "github.com/calypr/git-drs/s3_utils" "github.com/calypr/git-drs/utils" "github.com/hashicorp/go-multierror" "github.com/hashicorp/go-retryablehttp" @@ -39,7 +39,7 @@ type IndexDClient struct { ProjectId string BucketName string Logger *slog.Logger - AuthHandler s3_utils.AuthHandler // Injected for testing/flexibility + AuthHandler cloud.AuthHandler // Injected for testing/flexibility HttpClient *retryablehttp.Client SConfig sonic.API @@ -1066,36 +1066,6 @@ func (cl *IndexDClient) GetIndexdRecordByDID(did string) (*OutputInfo, error) { return record, nil } -func (cl *IndexDClient) BuildDrsObj(fileName string, checksum string, size int64, drsId string) (*drs.DRSObject, error) { - bucket := cl.BucketName - if bucket == "" { - return nil, fmt.Errorf("error: bucket name is empty in config file") - } - - //TODO: support other storage backends - fileURL := fmt.Sprintf("s3://%s", filepath.Join(bucket, drsId, checksum)) - - authzStr, err := utils.ProjectToResource(cl.GetProjectId()) - if err != nil { - return nil, err - } - authorizations := drs.Authorizations{ - Value: authzStr, - } - - // create DrsObj - DrsObj := 
drs.DRSObject{ - Id: drsId, - Name: fileName, - // TODO: ensure that we can retrieve the access method during submission (happens in transfer) - AccessMethods: []drs.AccessMethod{{Type: "s3", AccessURL: drs.AccessURL{URL: fileURL}, Authorizations: &authorizations}}, - Checksums: hash.HashInfo{SHA256: checksum}, - Size: size, - } - - return &DrsObj, nil -} - // Helper function to get indexd record by DID (similar to existing pattern in DeleteIndexdRecord) func (cl *IndexDClient) getIndexdRecordByDID(did string) (*OutputInfo, error) { url := fmt.Sprintf("%s/index/%s", cl.Base.String(), did) diff --git a/client/indexd/indexd_client_test.go b/client/indexd/indexd_client_test.go index e5b2c044..49828286 100644 --- a/client/indexd/indexd_client_test.go +++ b/client/indexd/indexd_client_test.go @@ -8,7 +8,6 @@ import ( "net/url" "os" "os/exec" - "path/filepath" "strings" "sync" "testing" @@ -301,24 +300,6 @@ func TestIndexdClient_RegisterAndUpdate(t *testing.T) { } } -func TestIndexdClient_BuildDrsObj(t *testing.T) { - client := &IndexDClient{ - ProjectId: "test-project", - BucketName: "bucket", - } - - obj, err := client.BuildDrsObj("file.txt", "sha-256", 12, "did-1") - if err != nil { - t.Fatalf("BuildDrsObj error: %v", err) - } - if obj.Id != "did-1" || obj.Checksums.SHA256 != "sha-256" { - t.Fatalf("unexpected drs object: %+v", obj) - } - if len(obj.AccessMethods) != 1 || !strings.Contains(obj.AccessMethods[0].AccessURL.URL, filepath.Join("bucket", "did-1", "sha-256")) { - t.Fatalf("unexpected access URL: %+v", obj.AccessMethods) - } -} - func TestIndexdClient_GetProfile(t *testing.T) { client := &IndexDClient{AuthHandler: &RealAuthHandler{Cred: confCredential("profile")}} profile, err := client.GetProfile() @@ -464,48 +445,3 @@ func chdirForTest(t *testing.T, dir string) func() { } } } - -func TestBuildDrsObj_Success(t *testing.T) { - client := &IndexDClient{ - ProjectId: "test-project", - BucketName: "bucket", - } - - obj, err := client.BuildDrsObj("file.txt", "sha-256", 12, "did-1") - if err != nil { - t.Fatalf("BuildDrsObj error: %v", err) - } - if obj.Id != "did-1" { - t.Fatalf("unexpected Id: %s", obj.Id) - } - if obj.Name != "file.txt" { - t.Fatalf("unexpected Name: %s", obj.Name) - } - if obj.Checksums.SHA256 != "sha-256" { - t.Fatalf("unexpected checksum: %v", obj.Checksums) - } - if obj.Size != 12 { - t.Fatalf("unexpected size: %d", obj.Size) - } - if len(obj.AccessMethods) != 1 { - t.Fatalf("expected 1 access method, got %d", len(obj.AccessMethods)) - } - if !strings.Contains(obj.AccessMethods[0].AccessURL.URL, filepath.Join("bucket", "did-1", "sha-256")) { - t.Fatalf("unexpected access URL: %s", obj.AccessMethods[0].AccessURL.URL) - } - if obj.AccessMethods[0].Type != "s3" { - t.Fatalf("unexpected access method type: %s", obj.AccessMethods[0].Type) - } -} - -func TestBuildDrsObj_EmptyBucket(t *testing.T) { - client := &IndexDClient{ - ProjectId: "test-project", - BucketName: "", - } - - _, err := client.BuildDrsObj("file.txt", "sha-256", 12, "did-1") - if err == nil { - t.Fatalf("expected error when BucketName is empty") - } -} diff --git a/client/indexd/tests/add-url-unit_test.go b/client/indexd/tests/add-url-unit_test.go index 2a971805..6e4cbce2 100644 --- a/client/indexd/tests/add-url-unit_test.go +++ b/client/indexd/tests/add-url-unit_test.go @@ -16,10 +16,10 @@ import ( "github.com/bytedance/sonic" "github.com/bytedance/sonic/encoder" indexd_client "github.com/calypr/git-drs/client/indexd" + "github.com/calypr/git-drs/cloud" "github.com/calypr/git-drs/drs" 
"github.com/calypr/git-drs/drslog" "github.com/calypr/git-drs/drsmap" - "github.com/calypr/git-drs/s3_utils" ) // Unit Tests for validateInputs @@ -48,7 +48,7 @@ func TestValidateInputs_ConcurrentCalls(t *testing.T) { errChan := make(chan error, 10) for i := 0; i < 10; i++ { go func() { - errChan <- s3_utils.ValidateInputs(validS3URL, validSHA256) + errChan <- cloud.ValidateInputs(validS3URL, validSHA256) }() } @@ -74,8 +74,8 @@ func TestGetBucketDetailsWithAuth_Success(t *testing.T) { // Capture the auth header set by the handler authHeaderValue = r.Header.Get("Authorization") - response := s3_utils.S3BucketsResponse{ - S3Buckets: map[string]*s3_utils.S3Bucket{ + response := cloud.S3BucketsResponse{ + S3Buckets: map[string]*cloud.S3Bucket{ "test-bucket": { Region: "us-west-2", EndpointURL: "https://s3.amazonaws.com", @@ -116,8 +116,8 @@ func TestGetBucketDetailsWithAuth_Success(t *testing.T) { func TestGetBucketDetailsWithAuth_BucketMissing(t *testing.T) { // Test that missing bucket returns proper error server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - response := s3_utils.S3BucketsResponse{ - S3Buckets: map[string]*s3_utils.S3Bucket{ + response := cloud.S3BucketsResponse{ + S3Buckets: map[string]*cloud.S3Bucket{ "other-bucket": { Region: "us-east-1", EndpointURL: "https://s3.amazonaws.com", @@ -146,12 +146,12 @@ func TestGetBucketDetailsWithAuth_BucketMissing(t *testing.T) { func TestGetBucketDetailsWithAuth_MissingFields(t *testing.T) { tests := []struct { name string - bucket s3_utils.S3Bucket + bucket cloud.S3Bucket wantErrMsg string }{ { name: "missing region", - bucket: s3_utils.S3Bucket{ + bucket: cloud.S3Bucket{ EndpointURL: "https://s3.amazonaws.com", Region: "", }, @@ -159,7 +159,7 @@ func TestGetBucketDetailsWithAuth_MissingFields(t *testing.T) { }, { name: "missing endpoint", - bucket: s3_utils.S3Bucket{ + bucket: cloud.S3Bucket{ EndpointURL: "", Region: "us-west-2", }, @@ -167,7 +167,7 @@ func TestGetBucketDetailsWithAuth_MissingFields(t *testing.T) { }, { name: "missing both", - bucket: s3_utils.S3Bucket{ + bucket: cloud.S3Bucket{ EndpointURL: "", Region: "", }, @@ -178,8 +178,8 @@ func TestGetBucketDetailsWithAuth_MissingFields(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - response := s3_utils.S3BucketsResponse{ - S3Buckets: map[string]*s3_utils.S3Bucket{ + response := cloud.S3BucketsResponse{ + S3Buckets: map[string]*cloud.S3Bucket{ "test-bucket": &tt.bucket, }, } @@ -236,8 +236,8 @@ func TestGetBucketDetailsWithAuth_WithToken(t *testing.T) { tokenReceived = strings.TrimPrefix(authHeader, "Bearer ") } - response := s3_utils.S3BucketsResponse{ - S3Buckets: map[string]*s3_utils.S3Bucket{ + response := cloud.S3BucketsResponse{ + S3Buckets: map[string]*cloud.S3Bucket{ "test-bucket": { Region: "us-west-2", EndpointURL: "https://s3.amazonaws.com", @@ -274,8 +274,8 @@ func TestGetBucketDetailsWithAuth_NoAuthHandler(t *testing.T) { authHeaderPresent = true } - response := s3_utils.S3BucketsResponse{ - S3Buckets: map[string]*s3_utils.S3Bucket{ + response := cloud.S3BucketsResponse{ + S3Buckets: map[string]*cloud.S3Bucket{ "test-bucket": { Region: "us-west-2", EndpointURL: "https://s3.amazonaws.com", @@ -350,7 +350,7 @@ func TestS3BucketsResponse_UnmarshalValid(t *testing.T) { "GS_BUCKETS": {} }` - var response s3_utils.S3BucketsResponse + var response cloud.S3BucketsResponse err := 
sonic.ConfigFastest.Unmarshal([]byte(jsonData), &response) if err != nil { t.Fatalf("Failed to unmarshal S3BucketsResponse: %v", err) @@ -384,7 +384,7 @@ func TestS3BucketsResponse_EmptyBuckets(t *testing.T) { "GS_BUCKETS": {} }` - var response s3_utils.S3BucketsResponse + var response cloud.S3BucketsResponse err := sonic.ConfigFastest.Unmarshal([]byte(jsonData), &response) if err != nil { t.Fatalf("Failed to unmarshal empty S3BucketsResponse: %v", err) @@ -402,7 +402,7 @@ func TestS3Bucket_MissingOptionalFields(t *testing.T) { "region": "us-west-2" }` - var bucket s3_utils.S3Bucket + var bucket cloud.S3Bucket err := sonic.ConfigFastest.Unmarshal([]byte(jsonData), &bucket) if err != nil { t.Fatalf("Failed to unmarshal S3Bucket: %v", err) @@ -450,7 +450,7 @@ func TestFetchS3Metadata_Success_WithProvidedClient(t *testing.T) { }) // Provide bucket details directly (bypass getBucketDetails) - bucketDetails := &s3_utils.S3Bucket{ + bucketDetails := &cloud.S3Bucket{ Region: "us-west-2", EndpointURL: s3Mock.URL(), Programs: []string{"test-program"}, @@ -489,7 +489,7 @@ func TestFetchS3Metadata_Success_WithCredentialsInParams(t *testing.T) { s3Mock.AddObject("test-bucket", "file.bam", 1024) - bucketDetails := &s3_utils.S3Bucket{ + bucketDetails := &cloud.S3Bucket{ Region: "us-west-2", EndpointURL: s3Mock.URL(), } @@ -532,7 +532,7 @@ func TestFetchS3Metadata_Success_UsingBucketDetailsFromGen3(t *testing.T) { s3Mock.AddObject("test-bucket", "data.bam", 512) // Bucket details from Gen3 (simulated) - bucketDetails := &s3_utils.S3Bucket{ + bucketDetails := &cloud.S3Bucket{ Region: "us-west-2", EndpointURL: s3Mock.URL(), } @@ -568,7 +568,7 @@ func TestFetchS3Metadata_Failure_InvalidS3URL(t *testing.T) { ctx := context.Background() ignoreAWSConfigFiles(t) - bucketDetails := &s3_utils.S3Bucket{ + bucketDetails := &cloud.S3Bucket{ Region: "us-west-2", EndpointURL: "http://endpoint", } @@ -602,7 +602,7 @@ func TestFetchS3Metadata_Failure_MissingCredentials(t *testing.T) { s3Mock := NewMockS3Server(t) defer s3Mock.Close() - bucketDetails := &s3_utils.S3Bucket{ + bucketDetails := &cloud.S3Bucket{ Region: "", // No region - this will definitely trigger validation error EndpointURL: s3Mock.URL(), } @@ -636,7 +636,7 @@ func TestFetchS3Metadata_Failure_MissingRegion(t *testing.T) { // Bucket details WITHOUT region ignoreAWSConfigFiles(t) - bucketDetails := &s3_utils.S3Bucket{ + bucketDetails := &cloud.S3Bucket{ EndpointURL: "http://s3-endpoint", // No region field } @@ -686,7 +686,7 @@ func TestFetchS3Metadata_Failure_S3ObjectNotFound(t *testing.T) { o.UsePathStyle = true }) - bucketDetails := &s3_utils.S3Bucket{ + bucketDetails := &cloud.S3Bucket{ Region: "us-west-2", EndpointURL: s3Mock.URL(), } @@ -736,7 +736,7 @@ func TestFetchS3Metadata_Success_NilContentLength(t *testing.T) { o.UsePathStyle = true }) - bucketDetails := &s3_utils.S3Bucket{ + bucketDetails := &cloud.S3Bucket{ Region: "us-west-2", EndpointURL: s3Mock.URL(), } @@ -772,7 +772,7 @@ func TestFetchS3Metadata_Success_ParameterPriorityOverBucketDetails(t *testing.T defer s3Mock.Close() // Bucket details with DIFFERENT endpoint - bucketDetails := &s3_utils.S3Bucket{ + bucketDetails := &cloud.S3Bucket{ Region: "us-east-1", // Different region EndpointURL: "http://different-endpoint", Programs: []string{"test-program"}, @@ -1288,7 +1288,7 @@ func TestCustomEndpointResolver(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - resolver := &s3_utils.CustomEndpointResolver{Endpoint: tt.endpoint} + resolver := 
&cloud.CustomEndpointResolver{Endpoint: tt.endpoint} endpoint, err := resolver.ResolveEndpoint(tt.service, tt.region) if err != nil { @@ -1375,7 +1375,7 @@ func TestS3URLParsing_EdgeCases(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := s3_utils.ValidateInputs(tt.s3URL, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") + err := cloud.ValidateInputs(tt.s3URL, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") if (err != nil) != tt.expectError { t.Errorf("validateInputs() for %s error = %v, expectError %v", tt.description, err, tt.expectError) } diff --git a/client/indexd/utils.go b/client/indexd/utils.go index 57bb4124..78dd33aa 100644 --- a/client/indexd/utils.go +++ b/client/indexd/utils.go @@ -7,7 +7,7 @@ import ( "net/http" "github.com/bytedance/sonic" - "github.com/calypr/git-drs/s3_utils" + "github.com/calypr/git-drs/cloud" ) // getBucketDetailsWithAuth fetches bucket details from Gen3 using an AuthHandler. @@ -18,7 +18,7 @@ import ( // - bucketsEndpointURL: full URL to the /user/data/buckets endpoint // - authHandler: handler for adding authentication headers // - httpClient: the HTTP client to use -func GetBucketDetailsWithAuth(ctx context.Context, bucket, bucketsEndpointURL string, authHandler s3_utils.AuthHandler, httpClient *http.Client) (*s3_utils.S3Bucket, error) { +func GetBucketDetailsWithAuth(ctx context.Context, bucket, bucketsEndpointURL string, authHandler cloud.AuthHandler, httpClient *http.Client) (*cloud.S3Bucket, error) { // Use provided client or create default if httpClient == nil { httpClient = &http.Client{} @@ -47,7 +47,7 @@ func GetBucketDetailsWithAuth(ctx context.Context, bucket, bucketsEndpointURL st } // extract bucket endpoint - var bucketInfo s3_utils.S3BucketsResponse + var bucketInfo cloud.S3BucketsResponse if err := sonic.ConfigFastest.NewDecoder(resp.Body).Decode(&bucketInfo); err != nil { return nil, fmt.Errorf("failed to decode bucket information: %w", err) } diff --git a/client/indexd/utils_test.go b/client/indexd/utils_test.go index ad53bf4c..b7b51028 100644 --- a/client/indexd/utils_test.go +++ b/client/indexd/utils_test.go @@ -7,7 +7,7 @@ import ( "testing" "github.com/bytedance/sonic/encoder" - "github.com/calypr/git-drs/s3_utils" + "github.com/calypr/git-drs/cloud" ) type testAuthHandler struct { @@ -25,8 +25,8 @@ func TestGetBucketDetailsWithAuth(t *testing.T) { var authHeader string server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { authHeader = r.Header.Get("Authorization") - resp := s3_utils.S3BucketsResponse{ - S3Buckets: map[string]*s3_utils.S3Bucket{ + resp := cloud.S3BucketsResponse{ + S3Buckets: map[string]*cloud.S3Bucket{ "bucket": {Region: "us-east-1", EndpointURL: "https://s3.example.com"}, }, } @@ -49,8 +49,8 @@ func TestGetBucketDetailsWithAuth(t *testing.T) { func TestGetBucketDetailsWithAuth_NotFound(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - resp := s3_utils.S3BucketsResponse{ - S3Buckets: map[string]*s3_utils.S3Bucket{}, + resp := cloud.S3BucketsResponse{ + S3Buckets: map[string]*cloud.S3Bucket{}, } w.WriteHeader(http.StatusOK) _ = encoder.NewStreamEncoder(w).Encode(resp) diff --git a/client/interface.go b/client/interface.go index 5ee8f6d6..ba6ff5f7 100644 --- a/client/interface.go +++ b/client/interface.go @@ -2,9 +2,9 @@ package client import ( "github.com/calypr/data-client/client/common" + "github.com/calypr/git-drs/cloud" 
"github.com/calypr/git-drs/drs" "github.com/calypr/git-drs/drs/hash" - "github.com/calypr/git-drs/s3_utils" ) type DRSClient interface { @@ -53,9 +53,6 @@ type DRSClient interface { // Fields allowed: URLs, authz, name, version, description UpdateRecord(updateInfo *drs.DRSObject, did string) (*drs.DRSObject, error) - // Create a DRS object given file info - BuildDrsObj(fileName string, checksum string, size int64, drsId string) (*drs.DRSObject, error) - // Add an S3 URL to an existing indexd record - AddURL(s3URL, sha256, awsAccessKey, awsSecretKey, regionFlag, endpointFlag string, opts ...s3_utils.AddURLOption) (s3_utils.S3Meta, error) + AddURL(s3URL, sha256, awsAccessKey, awsSecretKey, regionFlag, endpointFlag string, opts ...cloud.AddURLOption) (cloud.S3Meta, error) } diff --git a/client/tests/add-url-helper_test.go b/client/tests/add-url-helper_test.go index 9851d220..5330a8aa 100644 --- a/client/tests/add-url-helper_test.go +++ b/client/tests/add-url-helper_test.go @@ -6,8 +6,8 @@ import ( "testing" "github.com/bytedance/sonic/encoder" + "github.com/calypr/git-drs/cloud" "github.com/calypr/git-drs/drsmap" - "github.com/calypr/git-drs/s3_utils" "github.com/calypr/git-drs/utils" ) @@ -149,8 +149,8 @@ func TestGetBucketDetails_Gen3Success(t *testing.T) { return } - response := s3_utils.S3BucketsResponse{ - S3Buckets: map[string]*s3_utils.S3Bucket{ + response := cloud.S3BucketsResponse{ + S3Buckets: map[string]*cloud.S3Bucket{ "test-bucket": { Region: "us-west-2", EndpointURL: "https://s3.aws.amazon.com", diff --git a/client/tests/add-url_test.go b/client/tests/add-url_test.go index bb7c1e55..b581b02b 100644 --- a/client/tests/add-url_test.go +++ b/client/tests/add-url_test.go @@ -3,7 +3,7 @@ package client import ( "testing" - "github.com/calypr/git-drs/s3_utils" + "github.com/calypr/git-drs/cloud" ) // TestValidateInputs_ValidInputs tests validation with valid S3 URL and SHA256 @@ -36,7 +36,7 @@ func TestValidateInputs_ValidInputs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := s3_utils.ValidateInputs(tt.s3URL, tt.sha256) + err := cloud.ValidateInputs(tt.s3URL, tt.sha256) if (err != nil) != tt.wantErr { t.Errorf("validateInputs() error = %v, wantErr %v", err, tt.wantErr) } @@ -82,7 +82,7 @@ func TestValidateInputs_InvalidS3URL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := s3_utils.ValidateInputs(tt.s3URL, validSHA256) + err := cloud.ValidateInputs(tt.s3URL, validSHA256) if (err != nil) != tt.wantErr { t.Errorf("validateInputs() error = %v, wantErr %v", err, tt.wantErr) } @@ -133,7 +133,7 @@ func TestValidateInputs_InvalidSHA256(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := s3_utils.ValidateInputs(validS3URL, tt.sha256) + err := cloud.ValidateInputs(validS3URL, tt.sha256) if (err != nil) != tt.wantErr { t.Errorf("validateInputs() error = %v, wantErr %v", err, tt.wantErr) } @@ -147,7 +147,7 @@ func TestValidateInputs_SHA256Normalization(t *testing.T) { uppercaseSHA256 := "E3B0C44298FC1C149AFBF4C8996FB92427AE41E4649B934CA495991B7852B855" // Should not error on uppercase SHA256 (it gets normalized internally) - err := s3_utils.ValidateInputs(validS3URL, uppercaseSHA256) + err := cloud.ValidateInputs(validS3URL, uppercaseSHA256) if err != nil { t.Errorf("validateInputs() should accept uppercase SHA256, got error: %v", err) } @@ -159,7 +159,7 @@ func TestValidateInputs_HexDecodeValidation(t *testing.T) { // Test valid 64-character hex string validHex := 
"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" - err := s3_utils.ValidateInputs(validS3URL, validHex) + err := cloud.ValidateInputs(validS3URL, validHex) if err != nil { t.Errorf("validateInputs() error = %v, want nil", err) } @@ -167,7 +167,7 @@ func TestValidateInputs_HexDecodeValidation(t *testing.T) { // Test that hex.DecodeString is properly checked // This has correct length but invalid hex invalidHex := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" - err = s3_utils.ValidateInputs(validS3URL, invalidHex) + err = cloud.ValidateInputs(validS3URL, invalidHex) if err == nil { t.Errorf("validateInputs() should reject invalid hex, got nil error") } @@ -201,7 +201,7 @@ func TestValidateInputs_CaseSensitivity(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := s3_utils.ValidateInputs(tt.s3URL, validSHA256) + err := cloud.ValidateInputs(tt.s3URL, validSHA256) if (err != nil) != tt.wantErr { t.Errorf("validateInputs() error = %v, wantErr %v", err, tt.wantErr) } @@ -245,7 +245,7 @@ func TestValidateInputs_EdgeCases(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := s3_utils.ValidateInputs(tt.s3URL, tt.sha256) + err := cloud.ValidateInputs(tt.s3URL, tt.sha256) if (err != nil) != tt.wantErr { t.Errorf("validateInputs() error = %v, wantErr %v", err, tt.wantErr) } @@ -260,6 +260,6 @@ func BenchmarkValidateInputs(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _ = s3_utils.ValidateInputs(s3URL, sha256) + _ = cloud.ValidateInputs(s3URL, sha256) } } diff --git a/cloud/agent_fetch_reader.go b/cloud/agent_fetch_reader.go new file mode 100644 index 00000000..b6e094a4 --- /dev/null +++ b/cloud/agent_fetch_reader.go @@ -0,0 +1,192 @@ +package cloud + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "sync/atomic" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +// progressReader wraps an io.ReadCloser and periodically writes progress to stderr. +type progressReader struct { + rc io.ReadCloser + label string + start time.Time + total int64 // accessed atomically + quit chan struct{} + done chan struct{} + ticker time.Duration +} + +func newProgressReader(rc io.ReadCloser, label string) io.ReadCloser { + p := &progressReader{ + rc: rc, + label: label, + start: time.Now(), + quit: make(chan struct{}), + done: make(chan struct{}), + ticker: 500 * time.Millisecond, + } + + go func() { + t := time.NewTicker(p.ticker) + defer t.Stop() + var last int64 + for { + select { + case <-t.C: + total := atomic.LoadInt64(&p.total) + elapsed := time.Since(p.start).Seconds() + var rate float64 + if elapsed > 0 { + rate = float64(total) / elapsed + } + // \r to overwrite the same line like git pull; no newline until done + fmt.Fprintf(os.Stderr, "\r%s: %d bytes (%.1f KiB/s)", p.label, total, rate/1024) + last = total + case <-p.quit: + // final line (replace same line, then newline) + total := atomic.LoadInt64(&p.total) + _ = last // in case we want to use last for something later + fmt.Fprintf(os.Stderr, "\r%s: %d bytes\n", p.label, total) + close(p.done) + return + } + } + }() + + return p +} + +func (p *progressReader) Read(b []byte) (int, error) { + n, err := p.rc.Read(b) + if n > 0 { + atomic.AddInt64(&p.total, int64(n)) + } + return n, err +} + +func (p *progressReader) Close() error { + // Close underlying reader first, then stop progress goroutine and wait for completion. 
+ err := p.rc.Close() + close(p.quit) + <-p.done + return err +} + +// AgentFetchReader fetches the object described by `input` and returns an io.ReadCloser. +// It accepts `s3://bucket/key` URLs and converts them to HTTPS URLs. If `input.AWSEndpoint` +// is set it will use that endpoint in path-style (endpoint/bucket/key); otherwise it +// uses the default virtual-hosted AWS form: https://{bucket}.s3.amazonaws.com/{key}. +func AgentFetchReader(ctx context.Context, input S3ObjectParameters) (io.ReadCloser, error) { + if ctx == nil { + ctx = context.Background() + } + + raw := strings.TrimSpace(input.S3URL) + if raw == "" { + return nil, fmt.Errorf("AgentFetchReader: S3ObjectParameters.S3URL is empty") + } + + useSignedFetch := strings.TrimSpace(input.AWSAccessKey) != "" || + strings.TrimSpace(input.AWSSecretKey) != "" || + strings.TrimSpace(input.AWSRegion) != "" + if useSignedFetch { + if strings.TrimSpace(input.AWSAccessKey) == "" || strings.TrimSpace(input.AWSSecretKey) == "" || strings.TrimSpace(input.AWSRegion) == "" { + return nil, fmt.Errorf("AgentFetchReader: AWSAccessKey, AWSSecretKey, and AWSRegion are required for signed fetch") + } + + bucket, key, err := parseS3URL(raw) + if err != nil { + return nil, fmt.Errorf("AgentFetchReader: parse s3 url %q: %w", raw, err) + } + + s3Client, err := newS3Client(ctx, input) + if err != nil { + return nil, fmt.Errorf("AgentFetchReader: init s3 client: %w", err) + } + + out, err := s3Client.GetObject(ctx, &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + }) + if err != nil { + return nil, fmt.Errorf("AgentFetchReader: s3 GetObject failed (bucket=%q key=%q): %w", bucket, key, err) + } + if out.Body == nil { + return nil, fmt.Errorf("AgentFetchReader: response body is nil for s3://%s/%s", bucket, key) + } + + label := fmt.Sprintf("fetch s3://%s/%s", bucket, key) + return newProgressReader(out.Body, label), nil + } + + u, err := url.Parse(raw) + if err != nil { + return nil, fmt.Errorf("AgentFetchReader: parse url %q: %w", raw, err) + } + + var s3url string + switch u.Scheme { + case "s3": + bucket := u.Host + key := strings.TrimPrefix(u.Path, "/") + if bucket == "" || key == "" { + return nil, fmt.Errorf("AgentFetchReader: invalid s3 URL %q", raw) + } + if ep := strings.TrimSpace(input.AWSEndpoint); ep != "" { + // ensure endpoint has a scheme + if !strings.HasPrefix(ep, "http://") && !strings.HasPrefix(ep, "https://") { + ep = "https://" + ep + } + s3url = strings.TrimRight(ep, "/") + "/" + bucket + "/" + key + } else { + s3url = fmt.Sprintf("https://%s.s3.amazonaws.com/%s", bucket, key) + } + case "", "http", "https": + // allow bare host/path (no scheme) by assuming https, otherwise use as-is + if u.Scheme == "" { + s3url = "https://" + raw + } else { + s3url = raw + } + default: + return nil, fmt.Errorf("AgentFetchReader: unsupported URL scheme %q", u.Scheme) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, s3url, nil) + if err != nil { + return nil, fmt.Errorf("AgentFetchReader: create request: %w", err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("AgentFetchReader: http get %s: %w", s3url, err) + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + if _, copyErr := io.Copy(io.Discard, resp.Body); copyErr != nil { + _ = resp.Body.Close() + return nil, fmt.Errorf("AgentFetchReader: unexpected status %d fetching %s; failed to drain body: %w", resp.StatusCode, s3url, copyErr) + } + if closeErr := resp.Body.Close(); closeErr != nil { + return 
nil, fmt.Errorf("AgentFetchReader: unexpected status %d fetching %s; failed to close body: %w", resp.StatusCode, s3url, closeErr) + } + return nil, fmt.Errorf("AgentFetchReader: unexpected status %d fetching %s", resp.StatusCode, s3url) + } + + if resp.Body == nil { + return nil, fmt.Errorf("AgentFetchReader: response body is nil for %s", s3url) + } + // wrap response body with progress reporting that writes to stderr + label := fmt.Sprintf("fetch %s", s3url) + return newProgressReader(resp.Body, label), nil +} diff --git a/s3_utils/download.go b/cloud/download.go similarity index 98% rename from s3_utils/download.go rename to cloud/download.go index a2571ab9..e3b057ef 100644 --- a/s3_utils/download.go +++ b/cloud/download.go @@ -1,4 +1,4 @@ -package s3_utils +package cloud import ( "fmt" diff --git a/s3_utils/download_test.go b/cloud/download_test.go similarity index 98% rename from s3_utils/download_test.go rename to cloud/download_test.go index 507146f0..bb13bc45 100644 --- a/s3_utils/download_test.go +++ b/cloud/download_test.go @@ -1,4 +1,4 @@ -package s3_utils +package cloud import ( "net/http" diff --git a/cloud/downloader.go b/cloud/downloader.go new file mode 100644 index 00000000..19c0703f --- /dev/null +++ b/cloud/downloader.go @@ -0,0 +1,93 @@ +package cloud + +import ( + "context" + "crypto/sha256" + "fmt" + "io" + "os" + "path/filepath" +) + +// Download downloads the S3 object to a temporary file while computing its SHA256 hash. +// returns the computed SHA256 hash, temporary path and any error encountered. +func Download(ctx context.Context, info *S3Object, s3Input S3ObjectParameters, lfsRoot string) (string, string, error) { + // 2) object destination + etag := info.ETag + subdir1, subdir2 := "xx", "yy" + if len(etag) >= 4 { + subdir1 = etag[0:2] + subdir2 = etag[2:4] + } + objName := etag + if objName == "" { + objName = "unknown-etag" + } + tmpDir := filepath.Join(lfsRoot, "tmp-objects", subdir1, subdir2) + tmpObj := filepath.Join(tmpDir, objName) + + // 3) fetch bytes -> tmp, compute sha+count + + // Create the temporary directory and file where the S3 object will be streamed while computing its hash and size. 
+	if err := os.MkdirAll(tmpDir, 0755); err != nil {
+		return "", "", fmt.Errorf("mkdir %s: %w", tmpDir, err)
+	}
+
+	f, err := os.Create(tmpObj)
+	if err != nil {
+		return "", "", fmt.Errorf("create %s: %w", tmpObj, err)
+	}
+	// ensure any file still open on an early return is closed (close errors at
+	// that point are discarded, since an error is already being returned)
+	defer func() {
+		if f != nil {
+			if cerr := f.Close(); cerr != nil && err == nil {
+				err = fmt.Errorf("close tmp file: %w", cerr)
+			}
+		}
+	}()
+
+	h := sha256.New()
+
+	var reader io.ReadCloser
+	reader, err = AgentFetchReader(ctx, s3Input)
+	if err != nil {
+		return "", "", fmt.Errorf("fetch reader: %w", err)
+	}
+	// ensure the reader is closed on any early return (close errors here are
+	// discarded; the success path closes it explicitly below)
+	defer func() {
+		if reader != nil {
+			if cerr := reader.Close(); cerr != nil && err == nil {
+				err = fmt.Errorf("close reader: %w", cerr)
+			}
+		}
+	}()
+
+	n, err := io.Copy(io.MultiWriter(f, h), reader)
+	if err != nil {
+		return "", "", fmt.Errorf("copy bytes to %s: %w", tmpObj, err)
+	}
+
+	// explicitly close reader and handle error
+	if cerr := reader.Close(); cerr != nil {
+		return "", "", fmt.Errorf("close reader: %w", cerr)
+	}
+	reader = nil
+
+	// ensure data is flushed to disk
+	if err = f.Sync(); err != nil {
+		return "", "", fmt.Errorf("sync %s: %w", tmpObj, err)
+	}
+
+	// explicitly close the tmp file before the caller moves it into place
+	if cerr := f.Close(); cerr != nil {
+		return "", "", fmt.Errorf("close %s: %w", tmpObj, cerr)
+	}
+	f = nil
+
+	// use n (bytes written) to avoid unused var warnings
+	_ = n
+
+	// compute hex SHA256 of the fetched content
+	computedSHA := fmt.Sprintf("%x", h.Sum(nil))
+	return computedSHA, tmpObj, nil
+}
diff --git a/cloud/inspect.go b/cloud/inspect.go
new file mode 100644
index 00000000..07351088
--- /dev/null
+++ b/cloud/inspect.go
@@ -0,0 +1,261 @@
+// Package cloud provides small helpers for Git LFS + S3 object introspection.
+//
+// It:
+// 1. determines the effective Git LFS storage root (.git/lfs vs git config lfs.storage)
+// 2. derives a working-tree filename from the S3 object key (basename of key)
+// 3. performs an S3 HEAD Object to retrieve size and user metadata (sha256 if present)
+package cloud
+
+import (
+	"context"
+	"crypto/tls"
+	"errors"
+	"fmt"
+	"net/http"
+	"net/url"
+	"path"
+	"regexp"
+	"strings"
+	"time"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/config"
+	"github.com/aws/aws-sdk-go-v2/credentials"
+	"github.com/aws/aws-sdk-go-v2/service/s3"
+)
+
+// S3ObjectParameters is a container for S3 object identification and access.
+type S3ObjectParameters struct {
+	S3URL           string
+	AWSAccessKey    string
+	AWSSecretKey    string
+	AWSRegion       string
+	AWSEndpoint     string // optional: custom endpoint (Ceph/MinIO/etc.)
+	SHA256          string // optional expected hex (64 chars); may carry a "sha256:" prefix or be empty
+	DestinationPath string // optional override for the worktree filename (defaults to the basename of the key)
+}
+
+// S3Object describes the object returned by InspectS3ForLFS.
+type S3Object struct {
+
+	// Object identity
+	Bucket string
+	Key    string
+	Path   string // basename of Key (filename), or override from input
+
+	// HEAD-derived info
+	SizeBytes   int64
+	MetaSHA256  string // from user-defined object metadata (if present)
+	ETag        string
+	LastModTime time.Time
+}
+
+// InspectS3ForLFS parses the S3 URL, derives the worktree filename from the key
+// (or the DestinationPath override), and issues a HEAD Object request to retrieve
+// size, ETag, last-modified time, and any SHA256 recorded in user metadata.
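+//
+// A minimal usage sketch (values are illustrative):
+//
+//	obj, err := InspectS3ForLFS(ctx, S3ObjectParameters{
+//		S3URL:        "s3://my-bucket/data/file.bam",
+//		AWSRegion:    "us-west-2",
+//		AWSAccessKey: accessKey,
+//		AWSSecretKey: secretKey,
+//	})
+//	if err != nil {
+//		// handle error
+//	}
+//	// obj.Path == "file.bam"; obj.SizeBytes and obj.MetaSHA256 come from the HEAD response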
+func InspectS3ForLFS(ctx context.Context, in S3ObjectParameters) (*S3Object, error) { + if strings.TrimSpace(in.S3URL) == "" { + return nil, errors.New("S3URL is required") + } + if strings.TrimSpace(in.AWSRegion) == "" { + return nil, errors.New("AWSRegion is required") + } + if in.AWSAccessKey == "" || in.AWSSecretKey == "" { + return nil, errors.New("AWSAccessKey and AWSSecretKey are required") + } + + // 2) Parse S3 URL + derive working tree filename. + bucket, key, err := parseS3URL(in.S3URL) + if err != nil { + return nil, err + } + worktreeName := strings.TrimSpace(in.DestinationPath) + if worktreeName == "" { + worktreeName = path.Base(key) + if worktreeName == "." || worktreeName == "/" || worktreeName == "" { + return nil, fmt.Errorf("could not derive worktree name from key %q", key) + } + } else if worktreeName == "." || worktreeName == "/" { + return nil, fmt.Errorf("invalid worktree name override %q", worktreeName) + } + + // 3) HEAD on S3 to determine size and meta.SHA256. + s3Client, err := newS3Client(ctx, in) + if err != nil { + return nil, err + } + head, err := s3Client.HeadObject(ctx, &s3.HeadObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + }) + if err != nil { + return nil, fmt.Errorf("s3 HeadObject failed (bucket=%q key=%q): %w", bucket, key, err) + } + + metaSHA := extractSHA256FromMetadata(head.Metadata) + + // Optional: validate provided SHA256 against metadata if both exist. + expected := normalizeSHA256(in.SHA256) + if expected != "" && metaSHA != "" && !strings.EqualFold(expected, metaSHA) { + return nil, fmt.Errorf("sha256 mismatch: expected=%s head.meta=%s", expected, metaSHA) + } + + var lm time.Time + if head.LastModified != nil { + lm = *head.LastModified + } + + if head.ContentLength == nil { + return nil, fmt.Errorf("s3 HeadObject missing ContentLength (bucket=%q key=%q)", bucket, key) + } + sizeBytes := *head.ContentLength + + var etag string + if head.ETag != nil { + etag = strings.Trim(*head.ETag, `"`) + } + + out := &S3Object{ + Bucket: bucket, + Key: key, + Path: worktreeName, + SizeBytes: sizeBytes, + MetaSHA256: metaSHA, + ETag: etag, + LastModTime: lm, + } + return out, nil +} + +// +// --- S3 parsing + client --- +// + +var virtualHostedRE = regexp.MustCompile(`^(.+?)\.s3(?:[.-]|$)`) + +// parseS3URL parses s3://bucket/key, virtual-hosted HTTPS (bucket.s3.../key) +// and path-style HTTPS (s3.../bucket/key). Returns bucket and key. 
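+//
+// For example (the same cases exercised in inspect_test.go), each of
+//
+//	s3://my-bucket/path/to/file.bam
+//	https://my-bucket.s3.amazonaws.com/path/to/file.bam
+//	https://s3.example.org/my-bucket/path/to/file.bam
+//
+// resolves to bucket "my-bucket" and key "path/to/file.bam".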
+func parseS3URL(raw string) (string, string, error) { + u, err := url.Parse(raw) + if err != nil { + return "", "", err + } + + switch u.Scheme { + case "s3": + bucket := u.Host + key := strings.TrimPrefix(u.Path, "/") + return bucket, key, nil + case "http", "https": + host := u.Hostname() + + // virtual-hosted: bucket.s3.amazonaws.com or bucket.s3-region.amazonaws.com + if m := virtualHostedRE.FindStringSubmatch(host); m != nil { + bucket := m[1] + key := strings.TrimPrefix(u.Path, "/") + return bucket, key, nil + } + + // path-style: s3.../bucket/key + path := strings.TrimPrefix(u.Path, "/") + if path == "" { + return "", "", fmt.Errorf("no bucket in URL: %s", raw) + } + parts := strings.SplitN(path, "/", 2) + bucket := parts[0] + key := "" + if len(parts) == 2 { + key = parts[1] + } + return bucket, key, nil + default: + return "", "", fmt.Errorf("unsupported scheme: %s", u.Scheme) + } +} + +func newS3Client(ctx context.Context, in S3ObjectParameters) (*s3.Client, error) { + creds := credentials.NewStaticCredentialsProvider(in.AWSAccessKey, in.AWSSecretKey, "") + + // Custom HTTP client is useful for S3-compatible endpoints. + httpClient := &http.Client{ + Timeout: 60 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + }, + }, + } + + cfg, err := config.LoadDefaultConfig(ctx, + config.WithRegion(in.AWSRegion), + config.WithCredentialsProvider(creds), + config.WithHTTPClient(httpClient), + ) + if err != nil { + return nil, fmt.Errorf("aws config init failed: %w", err) + } + + opts := []func(*s3.Options){} + if strings.TrimSpace(in.AWSEndpoint) != "" { + ep := strings.TrimRight(in.AWSEndpoint, "/") + opts = append(opts, func(o *s3.Options) { + o.UsePathStyle = true // usually required for Ceph/MinIO/custom endpoints + o.BaseEndpoint = aws.String(ep) + }) + } + + return s3.NewFromConfig(cfg, opts...), nil +} + +// +// --- SHA256 metadata extraction --- +// + +var sha256HexRe = regexp.MustCompile(`(?i)^[0-9a-f]{64}$`) + +func normalizeSHA256(s string) string { + s = strings.TrimSpace(s) + s = strings.TrimPrefix(strings.ToLower(s), "sha256:") + s = strings.TrimSpace(s) + if s == "" { + return "" + } + if !sha256HexRe.MatchString(s) { + // If caller provided something malformed, treat as empty. + // Change this to a hard error if you prefer. + return "" + } + return strings.ToLower(s) +} + +func extractSHA256FromMetadata(md map[string]string) string { + if md == nil { + return "" + } + + // AWS SDK v2 exposes user-defined metadata WITHOUT the "x-amz-meta-" prefix, + // and normalizes keys to lower-case. 
+ candidates := []string{ + "sha256", + "checksum-sha256", + "content-sha256", + "oid-sha256", + "git-lfs-sha256", + } + + for _, k := range candidates { + if v, ok := md[k]; ok { + n := normalizeSHA256(v) + if n != "" { + return n + } + } + } + + // Sometimes people stash "sha256:" + for _, v := range md { + if n := normalizeSHA256(v); n != "" { + return n + } + } + + return "" +} diff --git a/cloud/inspect_test.go b/cloud/inspect_test.go new file mode 100644 index 00000000..77a5d449 --- /dev/null +++ b/cloud/inspect_test.go @@ -0,0 +1,123 @@ +package cloud + +import ( + "os" + "os/exec" + "strings" + "testing" +) + +func TestParseS3URL_S3Scheme(t *testing.T) { + b, k, err := parseS3URL("s3://my-bucket/path/to/file.bam") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if b != "my-bucket" { + t.Fatalf("bucket mismatch: %q", b) + } + if k != "path/to/file.bam" { + t.Fatalf("key mismatch: %q", k) + } +} + +func TestParseS3URL_HTTPSPathStyle(t *testing.T) { + b, k, err := parseS3URL("https://s3.example.org/my-bucket/path/to/file.bam") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if b != "my-bucket" { + t.Fatalf("bucket mismatch: %q", b) + } + if k != "path/to/file.bam" { + t.Fatalf("key mismatch: %q", k) + } +} + +func TestParseS3URL_HTTPSVirtualHosted(t *testing.T) { + b, k, err := parseS3URL("https://my-bucket.s3.amazonaws.com/path/to/file.bam") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if b != "my-bucket" { + t.Fatalf("bucket mismatch: %q", b) + } + if k != "path/to/file.bam" { + t.Fatalf("key mismatch: %q", k) + } +} + +func TestNormalizeSHA256(t *testing.T) { + hex := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + + if got := normalizeSHA256(hex); got != hex { + t.Fatalf("expected %q, got %q", hex, got) + } + + if got := normalizeSHA256("sha256:" + strings.ToUpper(hex)); got != hex { + t.Fatalf("expected %q, got %q", hex, got) + } + + if got := normalizeSHA256("not-a-sha"); got != "" { + t.Fatalf("expected empty for invalid, got %q", got) + } +} + +func TestExtractSHA256FromMetadata_ByKey(t *testing.T) { + hex := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + + md := map[string]string{ + "sha256": hex, + } + got := extractSHA256FromMetadata(md) + if got != hex { + t.Fatalf("expected %q, got %q", hex, got) + } +} + +func TestExtractSHA256FromMetadata_ByAlternateKey(t *testing.T) { + hex := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + + md := map[string]string{ + "checksum-sha256": "sha256:" + hex, + } + got := extractSHA256FromMetadata(md) + if got != hex { + t.Fatalf("expected %q, got %q", hex, got) + } +} + +func TestExtractSHA256FromMetadata_SearchValues(t *testing.T) { + hex := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + + md := map[string]string{ + "something": "sha256:" + hex, + } + got := extractSHA256FromMetadata(md) + if got != hex { + t.Fatalf("expected %q, got %q", hex, got) + } +} + +// --- test helpers --- + +func mustRun(t *testing.T, dir string, name string, args ...string) { + t.Helper() + cmd := exec.Command(name, args...) 
+ cmd.Dir = dir + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("command failed: %s %v\nerr=%v\nout=%s", name, args, err, string(out)) + } +} + +func mustChdir(t *testing.T, dir string) string { + t.Helper() + old, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd: %v", err) + } + if err := os.Chdir(dir); err != nil { + t.Fatalf("Chdir(%s): %v", dir, err) + } + return old +} diff --git a/s3_utils/s3.go b/cloud/s3.go similarity index 99% rename from s3_utils/s3.go rename to cloud/s3.go index 8b83bfbe..44dc868b 100644 --- a/s3_utils/s3.go +++ b/cloud/s3.go @@ -1,4 +1,4 @@ -package s3_utils +package cloud import ( "log/slog" diff --git a/s3_utils/s3_test.go b/cloud/s3_test.go similarity index 98% rename from s3_utils/s3_test.go rename to cloud/s3_test.go index 42fda897..b04d4a59 100644 --- a/s3_utils/s3_test.go +++ b/cloud/s3_test.go @@ -1,4 +1,4 @@ -package s3_utils +package cloud import ( "net/http" diff --git a/s3_utils/validate.go b/cloud/validate.go similarity index 97% rename from s3_utils/validate.go rename to cloud/validate.go index f67fcd9b..c876ea9d 100644 --- a/s3_utils/validate.go +++ b/cloud/validate.go @@ -1,4 +1,4 @@ -package s3_utils +package cloud import ( "encoding/hex" diff --git a/s3_utils/validate_test.go b/cloud/validate_test.go similarity index 95% rename from s3_utils/validate_test.go rename to cloud/validate_test.go index d2493d98..2c08270b 100644 --- a/s3_utils/validate_test.go +++ b/cloud/validate_test.go @@ -1,4 +1,4 @@ -package s3_utils +package cloud import "testing" diff --git a/cmd/addref/add-ref.go b/cmd/addref/add-ref.go index e4432002..27756e56 100644 --- a/cmd/addref/add-ref.go +++ b/cmd/addref/add-ref.go @@ -8,7 +8,7 @@ import ( "github.com/calypr/git-drs/config" "github.com/calypr/git-drs/drs/hash" "github.com/calypr/git-drs/drslog" - "github.com/calypr/git-drs/drsmap" + drslfs "github.com/calypr/git-drs/drsmap/lfs" "github.com/spf13/cobra" ) @@ -62,7 +62,7 @@ var Cmd = &cobra.Command{ os.MkdirAll(dirPath, os.ModePerm) } - err = drsmap.CreateLfsPointer(obj, dstPath) + err = drslfs.CreateLfsPointer(obj, dstPath) return err }, } diff --git a/cmd/addurl/main.go b/cmd/addurl/main.go index 12427369..d696ac17 100644 --- a/cmd/addurl/main.go +++ b/cmd/addurl/main.go @@ -1,119 +1,673 @@ package addurl import ( + "context" + "crypto/sha256" + "encoding/json" "errors" "fmt" + "log/slog" + "net/url" "os" - "os/exec" "path/filepath" + "sort" + "strings" + "time" + "github.com/calypr/git-drs/cloud" "github.com/calypr/git-drs/config" + "github.com/calypr/git-drs/drs" "github.com/calypr/git-drs/drslog" - "github.com/calypr/git-drs/s3_utils" + "github.com/calypr/git-drs/drsmap" + drslfs "github.com/calypr/git-drs/drsmap/lfs" + "github.com/calypr/git-drs/lfs" + "github.com/calypr/git-drs/precommit_cache" "github.com/calypr/git-drs/utils" "github.com/spf13/cobra" ) -// AddURLCmd represents the add-url command -var AddURLCmd = &cobra.Command{ - Use: "add-url ", - Short: "Add a file to the Git DRS repo using an S3 URL", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) != 2 { - cmd.SilenceUsage = false - return fmt.Errorf("error: requires exactly 2 arguments (S3 URL and SHA256), received %d\n\nUsage: %s\n\nSee 'git drs add-url --help' for more details", len(args), cmd.UseLine()) - } - return nil - }, - RunE: func(cmd *cobra.Command, args []string) error { - myLogger := drslog.GetLogger() - - // set git config lfs.allowincompletepush = true - configCmd := exec.Command("git", "config", "lfs.allowincompletepush", "true") - if err := 
configCmd.Run(); err != nil { - return fmt.Errorf("unable to configure git to push pointers: %v. Please change the .git/config file to include an [lfs] section with allowincompletepush = true", err) - } +var Cmd = NewCommand() - // Parse arguments - s3URL := args[0] - sha256 := args[1] - awsAccessKey, _ := cmd.Flags().GetString(s3_utils.AWS_KEY_FLAG_NAME) - awsSecretKey, _ := cmd.Flags().GetString(s3_utils.AWS_SECRET_FLAG_NAME) - awsRegion, _ := cmd.Flags().GetString(s3_utils.AWS_REGION_FLAG_NAME) - awsEndpoint, _ := cmd.Flags().GetString(s3_utils.AWS_ENDPOINT_URL_FLAG_NAME) - remote, _ := cmd.Flags().GetString("remote") - - // if providing credentials, access key and secret must both be provided - if (awsAccessKey == "" && awsSecretKey != "") || (awsAccessKey != "" && awsSecretKey == "") { - return errors.New("incomplete credentials provided as environment variables. Please run `export " + s3_utils.AWS_KEY_ENV_VAR + "=` and `export " + s3_utils.AWS_SECRET_ENV_VAR + "=` to configure both") - } +// NewCommand constructs the Cobra command for the `add-url` subcommand, +// wiring usage, argument validation and the RunE handler. +func NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "add-url [path]", + Short: "Add a file to the Git DRS repo using an S3 URL", + Args: func(cmd *cobra.Command, args []string) error { + if len(args) < 1 || len(args) > 2 { + return errors.New("usage: add-url [path]") + } + return nil + }, + RunE: runAddURL, + } + addFlags(cmd) + return cmd +} - // if none provided, use default AWS configuration on file - if awsAccessKey == "" && awsSecretKey == "" { - myLogger.Debug("No AWS credentials provided. Using default AWS configuration from file.") - } +// addFlags registers command-line flags for AWS credentials, endpoint and an +// optional `sha256` expected checksum. +func addFlags(cmd *cobra.Command) { + cmd.Flags().String( + cloud.AWS_KEY_FLAG_NAME, + os.Getenv(cloud.AWS_KEY_ENV_VAR), + "AWS access key", + ) - cfg, err := config.LoadConfig() - if err != nil { - return fmt.Errorf("error loading config: %v", err) + cmd.Flags().String( + cloud.AWS_SECRET_FLAG_NAME, + os.Getenv(cloud.AWS_SECRET_ENV_VAR), + "AWS secret key", + ) + + cmd.Flags().String( + cloud.AWS_REGION_FLAG_NAME, + os.Getenv(cloud.AWS_REGION_ENV_VAR), + "AWS S3 region", + ) + + cmd.Flags().String( + cloud.AWS_ENDPOINT_URL_FLAG_NAME, + os.Getenv(cloud.AWS_ENDPOINT_URL_ENV_VAR), + "AWS S3 endpoint (optional, for Ceph/MinIO)", + ) + + // New flag: optional expected SHA256 + cmd.Flags().String( + "sha256", + "", + "Expected SHA256 checksum (optional)", + ) +} + +// runAddURL is the Cobra RunE wrapper that delegates execution to the +func runAddURL(cmd *cobra.Command, args []string) (err error) { + return NewAddURLService().Run(cmd, args) +} + +// download uses cloud.AgentFetchReader to download the S3 object, returning +// the computed SHA256 and the path to the temporary downloaded file. +// The caller is responsible for moving/deleting the temporary file. +// we include this wrapper function to allow mocking in tests. +var download = func(ctx context.Context, info *cloud.S3Object, input cloud.S3ObjectParameters, lfsRoot string) (string, string, error) { + return cloud.Download(ctx, info, input, lfsRoot) +} + +// AddURLService groups injectable dependencies used to implement the add-url +// behavior (logger factory, S3 inspection, LFS helpers, config loader, etc.). 
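+//
+// NewAddURLService wires in the production implementations; tests in this
+// package can swap individual fields for stubs. A sketch (values illustrative):
+//
+//	svc := NewAddURLService()
+//	svc.inspectS3 = func(ctx context.Context, in cloud.S3ObjectParameters) (*cloud.S3Object, error) {
+//		return &cloud.S3Object{Bucket: "bucket", Key: "data/file.bam", Path: "file.bam", SizeBytes: 1024}, nil
+//	}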
+type AddURLService struct { + newLogger func(string, bool) (*slog.Logger, error) + inspectS3 func(ctx context.Context, input cloud.S3ObjectParameters) (*cloud.S3Object, error) + isLFSTracked func(path string) (bool, error) + getGitRoots func(ctx context.Context) (string, string, error) + gitLFSTrack func(ctx context.Context, path string) (bool, error) + loadConfig func() (*config.Config, error) + download func(ctx context.Context, info *cloud.S3Object, input cloud.S3ObjectParameters, lfsRoot string) (string, string, error) +} + +// NewAddURLService constructs an AddURLService populated with production +// implementations of its dependencies. +func NewAddURLService() *AddURLService { + return &AddURLService{ + newLogger: drslog.NewLogger, + inspectS3: cloud.InspectS3ForLFS, + isLFSTracked: lfs.IsLFSTracked, + getGitRoots: lfs.GetGitRootDirectories, + gitLFSTrack: lfs.GitLFSTrackReadOnly, + loadConfig: config.LoadConfig, + download: download, + } +} + +// Run executes the add-url workflow: parse CLI input, inspect the S3 object, +// ensure the LFS object exists in local storage, write a Git LFS pointer file, +// update the pre-commit cache (best-effort), optionally add a git-lfs track +// entry, and record the DRS mapping. +func (s *AddURLService) Run(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + if ctx == nil { + ctx = context.Background() + } + + logger, err := s.newLogger("", false) + if err != nil { + return fmt.Errorf("error creating logger: %v", err) + } + + input, err := parseAddURLInput(cmd, args) + if err != nil { + return err + } + + s3Info, err := s.inspectS3(ctx, input.s3Params) + if err != nil { + return err + } + + isTracked, err := s.isLFSTracked(input.path) + if err != nil { + return fmt.Errorf("check LFS tracking for %s: %w", input.path, err) + } + + gitCommonDir, lfsRoot, err := s.getGitRoots(ctx) + if err != nil { + return fmt.Errorf("get git root directories: %w", err) + } + + if err := printResolvedInfo(cmd, gitCommonDir, lfsRoot, s3Info, input.path, isTracked, input.sha256); err != nil { + return err + } + + oid, err := s.ensureLFSObject(ctx, s3Info, input, lfsRoot) + if err != nil { + return err + } + + if err := writePointerFile(input.path, oid, s3Info.SizeBytes); err != nil { + return err + } + + if err := updatePrecommitCache(ctx, logger, input.path, oid, input.s3URL); err != nil { + logger.Warn("pre-commit cache update skipped", "error", err) + } + + if err := maybeTrackLFS(ctx, s.gitLFSTrack, input.path, isTracked); err != nil { + return err + } + + cfg, err := s.loadConfig() + if err != nil { + return fmt.Errorf("error getting config: %v", err) + } + + remote, err := cfg.GetDefaultRemote() + if err != nil { + return err + } + + remoteConfig := cfg.GetRemote(remote) + if remoteConfig == nil { + return fmt.Errorf("error getting remote configuration for %s", remote) + } + + builder := drs.NewObjectBuilder(remoteConfig.GetBucketName(), remoteConfig.GetProjectId()) + + file := drslfs.LfsFileInfo{ + Name: input.path, + Size: s3Info.SizeBytes, + Oid: oid, + } + if _, err := drsmap.WriteDrsFile(builder, file, &input.s3URL); err != nil { + return fmt.Errorf("error WriteDrsFile: %v", err) + } + + return nil +} + +// updatePrecommitCache updates the project's pre-commit cache with a mapping +// from a repository-relative `pathArg` to the given LFS `oid` and records the +// external source URL. 
It will: +// - require a non-nil `logger` +// - open the pre-commit cache (`precommit_cache.Open`) +// - ensure cache directories exist +// - convert the supplied worktree path to a repository-relative path +// - create or update the per-path JSON entry with the current OID and timestamp +// - create or update the per-OID JSON entry listing paths that reference it, +// the external URL, and a content-change flag when the path's OID changed +// - remove the path from the previous OID entry when the content changed +// +// Parameters: +// - ctx: context for operations that may be cancellable +// - logger: a non-nil `*slog.Logger` used for warnings; if nil the function +// returns an error +// - pathArg: the worktree path to record (absolute or relative); must not be empty +// - oid: the LFS object id (string) to associate with the path +// - externalURL: optional external source URL for the object; empty string is allowed +// +// Returns an error if any cache operation, path resolution, or I/O fails. +func updatePrecommitCache(ctx context.Context, logger *slog.Logger, pathArg, oid, externalURL string) error { + if logger == nil { + return errors.New("logger is required") + } + // Open pre-commit cache. Returns a configured Cache or error. + cache, err := precommit_cache.Open(ctx) + if err != nil { + return err + } + + // Ensure cache directories exist. + if err := ensureCacheDirs(cache, logger); err != nil { + return err + } + + // Convert worktree path to repository-relative path. + relPath, err := repoRelativePath(pathArg) + if err != nil { + return err + } + + // Current timestamp in RFC3339 format (UTC). + now := time.Now().UTC().Format(time.RFC3339) + + // Read previous path entry, if any. + pathFile := cachePathEntryFile(cache, relPath) + prevEntry, prevExists, err := readPathEntry(pathFile) + if err != nil { + return err + } + // track whether content changed for this path + contentChanged := prevExists && prevEntry.LFSOID != "" && prevEntry.LFSOID != oid + + if err := writeJSONAtomic(pathFile, precommit_cache.PathEntry{ + Path: relPath, + LFSOID: oid, + UpdatedAt: now, + }); err != nil { + return err + } + + if err := upsertOIDEntry(cache, oid, relPath, externalURL, now, contentChanged); err != nil { + return err + } + + if contentChanged { + _ = removePathFromOID(cache, prevEntry.LFSOID, relPath, now) + } + + return nil +} + +// ensureCacheDirs verifies and creates the pre-commit cache directory layout +// (paths and oids directories). It logs a warning when creating a missing +// cache root. +func ensureCacheDirs(cache *precommit_cache.Cache, logger *slog.Logger) error { + if cache == nil { + return errors.New("cache is nil") + } + if _, err := os.Stat(cache.Root); err != nil { + if os.IsNotExist(err) { + logger.Warn("pre-commit cache directory missing; creating", "path", cache.Root) + } else { + return err } + } + if err := os.MkdirAll(cache.PathsDir, 0o755); err != nil { + return fmt.Errorf("create cache paths dir: %w", err) + } + if err := os.MkdirAll(cache.OIDsDir, 0o755); err != nil { + return fmt.Errorf("create cache oids dir: %w", err) + } + return nil +} - remoteName, err := cfg.GetRemoteOrDefault(remote) +// repoRelativePath converts a worktree path (absolute or relative) to a +// repository-relative path. It resolves symlinks and ensures the path is +// contained within the repository root. 
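+//
+// Rough illustration, assuming the repository root is /home/user/repo and the
+// files exist:
+//
+//    repoRelativePath("/home/user/repo/data/file.bin") // "data/file.bin", nil
+//    repoRelativePath("data/file.bin")                 // "data/file.bin", nil
+//    repoRelativePath("/tmp/elsewhere.bin")            // error (outside the repo root)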
+func repoRelativePath(pathArg string) (string, error) { + if pathArg == "" { + return "", errors.New("empty worktree path") + } + root, err := utils.GitTopLevel() + if err != nil { + return "", err + } + root, err = filepath.EvalSymlinks(root) + if err != nil { + return "", err + } + clean := filepath.Clean(pathArg) + if filepath.IsAbs(clean) { + clean, err = filepath.EvalSymlinks(clean) if err != nil { - return fmt.Errorf("error getting default remote: %v", err) + return "", err } - - drsClient, err := cfg.GetRemoteClient(remoteName, myLogger) + rel, err := filepath.Rel(root, clean) if err != nil { - return fmt.Errorf("error getting current remote client: %v", err) + return "", err + } + if strings.HasPrefix(rel, "..") { + return "", fmt.Errorf("path %s is outside repo root %s", clean, root) } + return filepath.ToSlash(rel), nil + } + return filepath.ToSlash(clean), nil +} - // Call client.AddURL to handle Gen3 interactions - meta, err := drsClient.AddURL(s3URL, sha256, awsAccessKey, awsSecretKey, awsRegion, awsEndpoint) - if err != nil { - return err +// cachePathEntryFile returns the filesystem path to the JSON path-entry file +// for the given repository-relative path within the provided Cache. +func cachePathEntryFile(cache *precommit_cache.Cache, path string) string { + return filepath.Join(cache.PathsDir, precommit_cache.EncodePath(path)+".json") +} + +// cacheOIDEntryFile returns the filesystem path to the JSON OID-entry file +// for the given LFS OID. The file is named by sha256(oid) to avoid filesystem +// restrictions and collisions. +func cacheOIDEntryFile(cache *precommit_cache.Cache, oid string) string { + sum := sha256.Sum256([]byte(oid)) + return filepath.Join(cache.OIDsDir, fmt.Sprintf("%x.json", sum[:])) +} + +// readPathEntry reads and parses a JSON PathEntry from disk. It returns the +// parsed entry, a boolean indicating existence, or an error on I/O/parse +// failure. +func readPathEntry(path string) (*precommit_cache.PathEntry, bool, error) { + data, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil, false, nil } + return nil, false, err + } + var entry precommit_cache.PathEntry + if err := json.Unmarshal(data, &entry); err != nil { + return nil, false, err + } + return &entry, true, nil +} - // Generate and add pointer file - _, relFilePath, err := utils.ParseS3URL(s3URL) - if err != nil { - return fmt.Errorf("failed to parse S3 URL: %w", err) +// readOIDEntry reads and parses a JSON OIDEntry from disk. If the file is +// missing it returns a freshly initialized entry (with LFSOID set to the +// supplied oid and UpdatedAt set to now). +func readOIDEntry(path string, oid string, now string) (*precommit_cache.OIDEntry, error) { + entry := &precommit_cache.OIDEntry{ + LFSOID: oid, + Paths: []string{}, + UpdatedAt: now, + } + data, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return entry, nil } - if err := generatePointerFile(relFilePath, sha256, meta.Size); err != nil { - return fmt.Errorf("failed to generate pointer file: %w", err) + return nil, err + } + if err := json.Unmarshal(data, entry); err != nil { + return nil, err + } + entry.LFSOID = oid + return entry, nil +} + +// upsertOIDEntry creates or updates the OID entry for `oid`, ensuring `path` +// is listed among its Paths, updating ExternalURL when provided, and setting +// content-change/state fields. The updated entry is written atomically. 
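+//
+// Sketch of the on-disk shape of an OID entry; the JSON field names are
+// assumptions based on how the entry is used here, not a definitive schema:
+//
+//    {
+//      "lfs_oid": "sha256:<hex>",
+//      "paths": ["data/file.bin"],
+//      "external_url": "s3://bucket/data/file.bin",
+//      "updated_at": "2024-01-02T03:04:05Z",
+//      "content_changed": false
+//    }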
+func upsertOIDEntry(cache *precommit_cache.Cache, oid, path, externalURL, now string, contentChanged bool) error { + if cache == nil { + return errors.New("cache is nil") + } + oidFile := cacheOIDEntryFile(cache, oid) + entry, err := readOIDEntry(oidFile, oid, now) + if err != nil { + return err + } + + pathSet := make(map[string]struct{}, len(entry.Paths)+1) + for _, p := range entry.Paths { + pathSet[p] = struct{}{} + } + if path != "" { + pathSet[path] = struct{}{} + } + entry.Paths = sortedKeys(pathSet) + entry.UpdatedAt = now + entry.ContentChange = entry.ContentChange || contentChanged + if strings.TrimSpace(externalURL) != "" { + entry.ExternalURL = externalURL + } + + return writeJSONAtomic(oidFile, entry) +} + +// removePathFromOID removes `path` from the OID entry for `oid` and writes +// the updated entry atomically. Missing entries are treated as empty. +// sortedKeys returns a sorted slice of keys from the provided string-set map. +func removePathFromOID(cache *precommit_cache.Cache, oid, path, now string) error { + if cache == nil { + return errors.New("cache is nil") + } + oidFile := cacheOIDEntryFile(cache, oid) + entry, err := readOIDEntry(oidFile, oid, now) + if err != nil { + return err + } + pathSet := make(map[string]struct{}, len(entry.Paths)) + for _, p := range entry.Paths { + if p == path { + continue } - myLogger.Debug("S3 URL successfully added to Git DRS repo.") - return nil - }, + pathSet[p] = struct{}{} + } + entry.Paths = sortedKeys(pathSet) + entry.UpdatedAt = now + + return writeJSONAtomic(oidFile, entry) +} + +// sortedKeys returns a sorted slice of keys from the provided string-set map. +func sortedKeys(set map[string]struct{}) []string { + keys := make([]string, 0, len(set)) + for key := range set { + keys = append(keys, key) + } + sort.Strings(keys) + return keys +} + +// writeJSONAtomic marshals `value` to JSON and writes it to `path` atomically +// by writing to a temporary file in the same directory and renaming it. It +// ensures parent directories exist. +func writeJSONAtomic(path string, value any) error { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + tmp := path + ".tmp" + data, err := json.Marshal(value) + if err != nil { + return err + } + if err := os.WriteFile(tmp, data, 0o644); err != nil { + return err + } + return os.Rename(tmp, path) } -func init() { - AddURLCmd.Flags().String(s3_utils.AWS_KEY_FLAG_NAME, os.Getenv(s3_utils.AWS_KEY_ENV_VAR), "AWS access key") - AddURLCmd.Flags().String(s3_utils.AWS_SECRET_FLAG_NAME, os.Getenv(s3_utils.AWS_SECRET_ENV_VAR), "AWS secret key") - AddURLCmd.Flags().String(s3_utils.AWS_REGION_FLAG_NAME, os.Getenv(s3_utils.AWS_REGION_ENV_VAR), "AWS S3 region") - AddURLCmd.Flags().String(s3_utils.AWS_ENDPOINT_URL_FLAG_NAME, os.Getenv(s3_utils.AWS_ENDPOINT_URL_ENV_VAR), "AWS S3 endpoint") - AddURLCmd.Flags().String("remote", "", "target remote DRS server (default: default_remote)") +// parseAddURLInput parses CLI args and flags into an addURLInput, validates +// required AWS credentials and region, and constructs cloud.S3ObjectParameters. 
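+//
+// Illustrative invocation (the exact flag and environment variable spellings
+// are the cloud.AWS_* constants and are not repeated here):
+//
+//    git drs add-url s3://bucket/path/to/file.bin data/file.bin
+//
+// With no AWS flags passed, the key, secret, region and endpoint fall back to
+// the corresponding environment variables, because addFlags uses them as the
+// flag defaults.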
+type addURLInput struct { + s3URL string + path string + sha256 string + s3Params cloud.S3ObjectParameters } -func generatePointerFile(filePath string, sha256 string, fileSize int64) error { - // Define the pointer file content - pointerContent := fmt.Sprintf("version https://git-lfs.github.com/spec/v1\noid sha256:%s\nsize %d\n", sha256, fileSize) +func parseAddURLInput(cmd *cobra.Command, args []string) (addURLInput, error) { + s3URL := args[0] - // Ensure the directory exists - if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil { - return fmt.Errorf("failed to create directory for pointer file: %w", err) + pathArg, err := resolvePathArg(s3URL, args) + if err != nil { + return addURLInput{}, err } - // Write the pointer file - if err := os.WriteFile(filePath, []byte(pointerContent), 0644); err != nil { - return fmt.Errorf("failed to write pointer file: %w", err) + sha256Param, err := cmd.Flags().GetString("sha256") + if err != nil { + return addURLInput{}, fmt.Errorf("read flag sha256: %w", err) } - // Add the pointer file to Git - cmd := exec.Command("git", "add", filePath) - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to add pointer file to Git: %w", err) + awsKey, err := cmd.Flags().GetString(cloud.AWS_KEY_FLAG_NAME) + if err != nil { + return addURLInput{}, fmt.Errorf("read flag %s: %w", cloud.AWS_KEY_FLAG_NAME, err) + } + awsSecret, err := cmd.Flags().GetString(cloud.AWS_SECRET_FLAG_NAME) + if err != nil { + return addURLInput{}, fmt.Errorf("read flag %s: %w", cloud.AWS_SECRET_FLAG_NAME, err) + } + awsRegion, err := cmd.Flags().GetString(cloud.AWS_REGION_FLAG_NAME) + if err != nil { + return addURLInput{}, fmt.Errorf("read flag %s: %w", cloud.AWS_REGION_FLAG_NAME, err) + } + awsEndpoint, err := cmd.Flags().GetString(cloud.AWS_ENDPOINT_URL_FLAG_NAME) + if err != nil { + return addURLInput{}, fmt.Errorf("read flag %s: %w", cloud.AWS_ENDPOINT_URL_FLAG_NAME, err) } + if awsKey == "" || awsSecret == "" { + return addURLInput{}, errors.New("AWS credentials must be provided via flags or environment variables") + } + if awsRegion == "" { + return addURLInput{}, errors.New("AWS region must be provided via flag or environment variable") + } + + s3Input := cloud.S3ObjectParameters{ + S3URL: s3URL, + AWSAccessKey: awsKey, + AWSSecretKey: awsSecret, + AWSRegion: awsRegion, + AWSEndpoint: awsEndpoint, + SHA256: sha256Param, + DestinationPath: pathArg, + } + + return addURLInput{ + s3URL: s3URL, + path: pathArg, + sha256: sha256Param, + s3Params: s3Input, + }, nil +} + +// resolvePathArg returns the explicit destination path argument when provided, +// otherwise derives the worktree path from the given S3 URL path component. +func resolvePathArg(s3URL string, args []string) (string, error) { + if len(args) == 2 { + return args[1], nil + } + u, err := url.Parse(s3URL) + if err != nil { + return "", err + } + return strings.TrimPrefix(u.Path, "/"), nil +} + +// printResolvedInfo writes a human-readable summary of resolved Git/LFS and +// S3 object information to the command's stdout for user confirmation. 
+func printResolvedInfo(cmd *cobra.Command, gitCommonDir, lfsRoot string, s3Info *cloud.S3Object, pathArg string, isTracked bool, sha256 string) error { + if _, err := fmt.Fprintf(cmd.OutOrStdout(), ` +Resolved Git LFS s3Info +--------------------- +Git common dir : %s +LFS storage : %s + +S3 object +--------- +Bucket : %s +Key : %s +Worktree name : %s +Size (bytes) : %d +SHA256 (meta) : %s +ETag : %s +Last modified : %s + +Worktree +------------- +path : %s +tracked by LFS : %v +sha256 param : %s + +`, + gitCommonDir, + lfsRoot, + s3Info.Bucket, + s3Info.Key, + s3Info.Path, + s3Info.SizeBytes, + s3Info.MetaSHA256, + s3Info.ETag, + s3Info.LastModTime.Format("2006-01-02T15:04:05Z07:00"), + pathArg, + isTracked, + sha256, + ); err != nil { + return fmt.Errorf("print resolved s3Info: %w", err) + } + return nil +} + +// ensureLFSObject ensures the LFS object identified by s3Info exists in the +// repository's LFS storage. If the input includes an explicit SHA256 that is +// returned immediately; otherwise the object is downloaded into a temporary +// file and moved into the LFS `objects` storage, returning the object's oid. +func (s *AddURLService) ensureLFSObject(ctx context.Context, s3Info *cloud.S3Object, input addURLInput, lfsRoot string) (string, error) { + if input.sha256 != "" { + return input.sha256, nil + } + + computedSHA, tmpObj, err := s.download(ctx, s3Info, input.s3Params, lfsRoot) + if err != nil { + return "", err + } + + oid := computedSHA + dstDir := filepath.Join(lfsRoot, "objects", oid[0:2], oid[2:4]) + dstObj := filepath.Join(dstDir, oid) + + if err := os.MkdirAll(dstDir, 0755); err != nil { + return "", fmt.Errorf("mkdir %s: %w", dstDir, err) + } + + if err := os.Rename(tmpObj, dstObj); err != nil { + return "", fmt.Errorf("rename %s to %s: %w", tmpObj, dstObj, err) + } + + if _, err := fmt.Fprintf(os.Stderr, "Added data file at %s\n", dstObj); err != nil { + return "", fmt.Errorf("stderr write: %w", err) + } + + return computedSHA, nil +} + +// writePointerFile writes a Git LFS pointer file at the given worktree path +// referencing the supplied oid and recording sizeBytes. It creates parent +// directories as needed and validates the path is non-empty. +func writePointerFile(pathArg, oid string, sizeBytes int64) error { + pointer := fmt.Sprintf( + "version https://git-lfs.github.com/spec/v1\noid sha256:%s\nsize %d\n", + oid, sizeBytes, + ) + if pathArg == "" { + return fmt.Errorf("empty worktree path") + } + safePath := filepath.Clean(pathArg) + dir := filepath.Dir(safePath) + if dir != "." && dir != "/" { + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("mkdir %s: %w", dir, err) + } + } + if err := os.WriteFile(safePath, []byte(pointer), 0644); err != nil { + return fmt.Errorf("write %s: %w", safePath, err) + } + + if _, err := fmt.Fprintf(os.Stderr, "Added Git LFS pointer file at %s\n", safePath); err != nil { + return fmt.Errorf("stderr write: %w", err) + } + return nil +} + +// maybeTrackLFS ensures the supplied path is tracked by Git LFS by invoking +// the provided gitLFSTrack callback when the path is not already tracked. +// It reports the addition to stderr for user guidance. +func maybeTrackLFS(ctx context.Context, gitLFSTrack func(context.Context, string) (bool, error), pathArg string, isTracked bool) error { + if isTracked { + return nil + } + if _, err := gitLFSTrack(ctx, pathArg); err != nil { + return fmt.Errorf("git lfs track %s: %w", pathArg, err) + } + + if _, err := fmt.Fprintf(os.Stderr, "Info: Added to Git LFS. 
Remember to `git add %s` and `git commit ...`", pathArg); err != nil { + return fmt.Errorf("stderr write: %w", err) + } return nil } diff --git a/cmd/addurl/main_test.go b/cmd/addurl/main_test.go index 759c01dd..875dfca8 100644 --- a/cmd/addurl/main_test.go +++ b/cmd/addurl/main_test.go @@ -1,28 +1,334 @@ package addurl import ( + "bytes" + "context" + "crypto/sha256" + "encoding/json" + "fmt" + "io" + "log/slog" "os" + "os/exec" "path/filepath" + "strings" "testing" + "time" - "github.com/calypr/git-drs/internal/testutils" + "github.com/calypr/git-drs/cloud" + "github.com/calypr/git-drs/drsmap" + "github.com/calypr/git-drs/precommit_cache" + "github.com/spf13/cobra" ) -func TestGeneratePointerFile(t *testing.T) { - testutils.SetupTestGitRepo(t) +func TestRunAddURL_WritesPointerAndLFSObject(t *testing.T) { + content := "hello world" + sum := sha256.Sum256([]byte(content)) + shaHex := fmt.Sprintf("%x", sum[:]) - path := filepath.Join("data", "file.txt") - err := generatePointerFile(path, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", 10) - if err != nil { - t.Fatalf("generatePointerFile error: %v", err) + tempDir := t.TempDir() + lfsRoot := filepath.Join(tempDir, ".git", "lfs") + + // ensure a git repository exists so any git-based config lookups succeed + cmdInit := exec.Command("git", "init") + cmdInit.Dir = tempDir + if out, err := cmdInit.CombinedOutput(); err != nil { + t.Fatalf("git init failed: %v: %s", err, out) + } + + // create a minimal drs config so runAddURL doesn't fail with + // "config file does not exist. Please run 'git drs init'..." + configPaths := []string{ + filepath.Join(tempDir, ".git", "drs", "config.yaml"), + } + for _, p := range configPaths { + // ensure parent dir exists for safety (e.g. .git should already exist from git init) + if dir := filepath.Dir(p); dir != tempDir && dir != "." 
{ + _ = os.MkdirAll(dir, 0755) + } + yamlConfig := ` +default_remote: calypr-dev +remotes: + calypr-dev: + gen3: + endpoint: https://calypr-dev.ohsu.edu + project_id: cbds-monorepos + bucket: cbds +` + if err := os.WriteFile(p, []byte(yamlConfig), 0644); err != nil { + t.Fatalf("write config %s: %v", p, err) + } + fmt.Fprintf(os.Stderr, "TestRunAddURL_WritesPointerAndLFSObject wrote mock config file %s\n", p) } - content, err := os.ReadFile(path) + service := NewAddURLService() + resetStubs := stubAddURLDeps(t, service, + func(ctx context.Context, in cloud.S3ObjectParameters) (*cloud.S3Object, error) { + return &cloud.S3Object{ + Bucket: "bucket", + Key: "path/to/file.bin", + Path: "file.bin", + SizeBytes: int64(len(content)), + MetaSHA256: "", + ETag: "abcd1234", + LastModTime: time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC), + }, nil + }, + func(path string) (bool, error) { + return true, nil + }, + // download stub: write the LFS object into lfsRoot and return the sha + func(ctx context.Context, info *cloud.S3Object, input cloud.S3ObjectParameters, lfsRootPath string) (string, string, error) { + objPath := filepath.Join(lfsRootPath, "objects", shaHex[0:2], shaHex[2:4], shaHex) + if err := os.MkdirAll(filepath.Dir(objPath), 0755); err != nil { + return "", "", err + } + if err := os.WriteFile(objPath, []byte(content), 0644); err != nil { + return "", "", err + } + return shaHex, objPath, nil + }, + ) + t.Cleanup(resetStubs) + + cmd := NewCommand() + var out bytes.Buffer + cmd.SetOut(&out) + requireFlags(t, cmd) + + oldwd := mustChdir(t, tempDir) + t.Cleanup(func() { _ = os.Chdir(oldwd) }) + + if err := service.Run(cmd, []string{"s3://bucket/path/to/file.bin"}); err != nil { + t.Fatalf("service.Run error: %v", err) + } + + pointerPath := filepath.Join(tempDir, "path/to/file.bin") + pointerBytes, err := os.ReadFile(pointerPath) if err != nil { t.Fatalf("read pointer file: %v", err) } + expectedPointer := fmt.Sprintf( + "version https://git-lfs.github.com/spec/v1\noid sha256:%s\nsize %d\n", + shaHex, + len(content), + ) + if string(pointerBytes) != expectedPointer { + t.Fatalf("pointer mismatch: expected %q, got %q", expectedPointer, string(pointerBytes)) + } + + lfsObject := filepath.Join(lfsRoot, "objects", shaHex[0:2], shaHex[2:4], shaHex) + if _, err := os.Stat(lfsObject); err != nil { + t.Fatalf("expected LFS object at %s: %v", lfsObject, err) + } + + drsObject, err := drsmap.DrsInfoFromOid(shaHex) + if err != nil { + t.Fatalf("read drs object: %v", err) + } + if len(drsObject.AccessMethods) == 0 { + t.Fatalf("expected access methods in drs object") + } + if got := drsObject.AccessMethods[0].AccessURL.URL; got != "s3://bucket/path/to/file.bin" { + t.Fatalf("unexpected access URL: %s", got) + } +} + +func TestUpdatePrecommitCacheWritesEntries(t *testing.T) { + repo := setupGitRepo(t) + path := filepath.Join(repo, "data", "file.bin") + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(path, []byte("data"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + oldwd := mustChdir(t, repo) + t.Cleanup(func() { _ = os.Chdir(oldwd) }) + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + oid := "sha256deadbeef" + externalURL := "s3://bucket/data/file.bin" + + if err := updatePrecommitCache(context.Background(), logger, path, oid, externalURL); err != nil { + t.Fatalf("updatePrecommitCache: %v", err) + } + + cacheRoot := filepath.Join(repo, ".git", "drs", "pre-commit", "v1") + pathsDir := filepath.Join(cacheRoot, 
"paths") + oidDir := filepath.Join(cacheRoot, "oids") - if len(content) == 0 { - t.Fatalf("expected pointer file content") + pathEntryFile := filepath.Join(pathsDir, precommit_cache.EncodePath("data/file.bin")+".json") + pathData, err := os.ReadFile(pathEntryFile) + if err != nil { + t.Fatalf("read path entry: %v", err) + } + var pathEntry precommit_cache.PathEntry + if err := json.Unmarshal(pathData, &pathEntry); err != nil { + t.Fatalf("unmarshal path entry: %v", err) + } + if pathEntry.Path != "data/file.bin" { + t.Fatalf("expected path entry path to be %q, got %q", "data/file.bin", pathEntry.Path) + } + if pathEntry.LFSOID != oid { + t.Fatalf("expected path entry oid to be %q, got %q", oid, pathEntry.LFSOID) + } + if pathEntry.UpdatedAt == "" { + t.Fatalf("expected updated_at to be set") + } + + oidSum := sha256.Sum256([]byte(oid)) + oidEntryFile := filepath.Join(oidDir, fmt.Sprintf("%x.json", oidSum[:])) + oidData, err := os.ReadFile(oidEntryFile) + if err != nil { + t.Fatalf("read oid entry: %v", err) + } + var oidEntry precommit_cache.OIDEntry + if err := json.Unmarshal(oidData, &oidEntry); err != nil { + t.Fatalf("unmarshal oid entry: %v", err) + } + if oidEntry.LFSOID != oid { + t.Fatalf("expected oid entry oid to be %q, got %q", oid, oidEntry.LFSOID) + } + if oidEntry.ExternalURL != externalURL { + t.Fatalf("expected oid entry external_url to be %q, got %q", externalURL, oidEntry.ExternalURL) + } + if len(oidEntry.Paths) != 1 || oidEntry.Paths[0] != "data/file.bin" { + t.Fatalf("expected oid entry paths to include data/file.bin, got %v", oidEntry.Paths) + } + if oidEntry.UpdatedAt == "" { + t.Fatalf("expected oid entry updated_at to be set") + } +} + +func TestUpdatePrecommitCacheContentChanged(t *testing.T) { + repo := setupGitRepo(t) + path := filepath.Join(repo, "data", "file.bin") + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(path, []byte("data"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + oldwd := mustChdir(t, repo) + t.Cleanup(func() { _ = os.Chdir(oldwd) }) + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + firstOID := "sha256first" + secondOID := "sha256second" + + if err := updatePrecommitCache(context.Background(), logger, path, firstOID, "s3://bucket/first"); err != nil { + t.Fatalf("updatePrecommitCache first: %v", err) + } + if err := updatePrecommitCache(context.Background(), logger, path, secondOID, "s3://bucket/second"); err != nil { + t.Fatalf("updatePrecommitCache second: %v", err) + } + + cacheRoot := filepath.Join(repo, ".git", "drs", "pre-commit", "v1") + oidDir := filepath.Join(cacheRoot, "oids") + + firstSum := sha256.Sum256([]byte(firstOID)) + firstEntryFile := filepath.Join(oidDir, fmt.Sprintf("%x.json", firstSum[:])) + firstData, err := os.ReadFile(firstEntryFile) + if err != nil { + t.Fatalf("read first oid entry: %v", err) + } + var firstEntry precommit_cache.OIDEntry + if err := json.Unmarshal(firstData, &firstEntry); err != nil { + t.Fatalf("unmarshal first oid entry: %v", err) + } + if len(firstEntry.Paths) != 0 { + t.Fatalf("expected old oid entry paths to be empty, got %v", firstEntry.Paths) + } + + secondSum := sha256.Sum256([]byte(secondOID)) + secondEntryFile := filepath.Join(oidDir, fmt.Sprintf("%x.json", secondSum[:])) + secondData, err := os.ReadFile(secondEntryFile) + if err != nil { + t.Fatalf("read second oid entry: %v", err) + } + var secondEntry precommit_cache.OIDEntry + if err := json.Unmarshal(secondData, &secondEntry); err != nil 
{ + t.Fatalf("unmarshal second oid entry: %v", err) + } + if !secondEntry.ContentChange { + t.Fatalf("expected content_changed to be true") + } + if len(secondEntry.Paths) != 1 || secondEntry.Paths[0] != "data/file.bin" { + t.Fatalf("expected new oid entry paths to include data/file.bin, got %v", secondEntry.Paths) + } +} + +// deprecated test case: now that we always "trust" the client-provided SHA256, this case is not applicable +//func TestRunAddURL_SHA256Mismatch(t *testing.T) { +// ... +//} + +func stubAddURLDeps( + t *testing.T, + service *AddURLService, + inspectFn func(context.Context, cloud.S3ObjectParameters) (*cloud.S3Object, error), + isTrackedFn func(string) (bool, error), + downloadFn func(context.Context, *cloud.S3Object, cloud.S3ObjectParameters, string) (string, string, error), +) func() { + t.Helper() + origInspect := service.inspectS3 + origIsTracked := service.isLFSTracked + origDownload := service.download + + service.inspectS3 = inspectFn + service.isLFSTracked = isTrackedFn + service.download = downloadFn + + return func() { + service.inspectS3 = origInspect + service.isLFSTracked = origIsTracked + service.download = origDownload + } +} + +func requireFlags(t *testing.T, cmd *cobra.Command) { + t.Helper() + if err := cmd.Flags().Set(cloud.AWS_KEY_FLAG_NAME, "key"); err != nil { + t.Fatalf("set aws key: %v", err) + } + if err := cmd.Flags().Set(cloud.AWS_SECRET_FLAG_NAME, "secret"); err != nil { + t.Fatalf("set aws secret: %v", err) + } + if err := cmd.Flags().Set(cloud.AWS_REGION_FLAG_NAME, "region"); err != nil { + t.Fatalf("set aws region: %v", err) + } +} + +func mustChdir(t *testing.T, dir string) string { + t.Helper() + old, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd: %v", err) + } + if err := os.Chdir(dir); err != nil { + t.Fatalf("Chdir(%s): %v", dir, err) + } + return old +} + +func setupGitRepo(t *testing.T) string { + t.Helper() + dir := t.TempDir() + gitCmd(t, dir, "init") + gitCmd(t, dir, "config", "user.email", "test@example.com") + gitCmd(t, dir, "config", "user.name", "Test User") + return dir +} + +func gitCmd(t *testing.T, dir string, args ...string) { + t.Helper() + cmd := exec.Command("git", args...) 
+ cmd.Dir = dir + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("git %s failed: %v (%s)", strings.Join(args, " "), err, string(out)) } } diff --git a/cmd/download/main.go b/cmd/download/main.go index 2a33a020..90da849a 100644 --- a/cmd/download/main.go +++ b/cmd/download/main.go @@ -3,11 +3,11 @@ package download import ( "fmt" + "github.com/calypr/git-drs/cloud" "github.com/calypr/git-drs/config" "github.com/calypr/git-drs/drslog" "github.com/calypr/git-drs/drsmap" "github.com/calypr/git-drs/projectdir" - "github.com/calypr/git-drs/s3_utils" "github.com/spf13/cobra" ) @@ -66,7 +66,7 @@ var Cmd = &cobra.Command{ if err != nil { return fmt.Errorf("Error getting destination path for OID %s: %v", oid, err) } - err = s3_utils.DownloadSignedUrl(accessUrl.URL, dstPath) + err = cloud.DownloadSignedUrl(accessUrl.URL, dstPath) if err != nil { return fmt.Errorf("Error downloading file for OID %s: %v", oid, err) } diff --git a/cmd/initialize/main.go b/cmd/initialize/main.go index fc4c9711..a03e67a5 100644 --- a/cmd/initialize/main.go +++ b/cmd/initialize/main.go @@ -65,6 +65,11 @@ var Cmd = &cobra.Command{ if err != nil { return fmt.Errorf("error installing pre-push hook: %v", err) } + // install pre-commit hook + err = installPreCommitHook(logg) + if err != nil { + return fmt.Errorf("error installing pre-commit hook: %v", err) + } // final logs logg.Debug("Git DRS initialized") @@ -155,3 +160,49 @@ exec git lfs pre-push "$remote" "$url" < "$TMPFILE" logger.Debug("pre-push hook installed") return nil } + +func installPreCommitHook(logger *slog.Logger) error { + cmd := exec.Command("git", "rev-parse", "--git-dir") + cmdOut, err := cmd.Output() + if err != nil { + return fmt.Errorf("unable to locate git directory: %w", err) + } + gitDir := strings.TrimSpace(string(cmdOut)) + hooksDir := filepath.Join(gitDir, "hooks") + if err := os.MkdirAll(hooksDir, 0755); err != nil { + return fmt.Errorf("unable to create hooks directory: %w", err) + } + + hookPath := filepath.Join(hooksDir, "pre-commit") + hookBody := ` +# .git/hooks/pre-commit +exec git drs precommit +` + hookScript := "#!/bin/sh\n" + hookBody + + existingContent, err := os.ReadFile(hookPath) + if err == nil { + // there is an existing hook, rename it, and let the user know + // Backup existing hook with timestamp + timestamp := time.Now().Format("20060102T150405") + backupPath := hookPath + "." 
+ timestamp + if err := os.WriteFile(backupPath, existingContent, 0644); err != nil { + return fmt.Errorf("unable to back up existing pre-commit hook: %w", err) + } + if err := os.Remove(hookPath); err != nil { + return fmt.Errorf("unable to remove hook after backing up: %w", err) + } + logger.Debug(fmt.Sprintf("pre-commit hook updated; backup written to %s", backupPath)) + } + // If there was an error other than expected not existing, return it + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("unable to read pre-commit hook: %w", err) + } + + err = os.WriteFile(hookPath, []byte(hookScript), 0755) + if err != nil { + return fmt.Errorf("unable to write pre-commit hook: %w", err) + } + logger.Debug("pre-commit hook installed") + return nil +} diff --git a/cmd/initialize/main_test.go b/cmd/initialize/main_test.go index 0fae869b..6ed54d46 100644 --- a/cmd/initialize/main_test.go +++ b/cmd/initialize/main_test.go @@ -32,6 +32,28 @@ func TestInstallPrePushHook(t *testing.T) { } } +func TestInstallPreCommitHook(t *testing.T) { + testutils.SetupTestGitRepo(t) + logger := drslog.NewNoOpLogger() + + if err := installPreCommitHook(logger); err != nil { + t.Fatalf("installPreCommitHook error: %v", err) + } + + hookPath := filepath.Join(".git", "hooks", "pre-commit") + content, err := os.ReadFile(hookPath) + if err != nil { + t.Fatalf("read hook: %v", err) + } + if !strings.Contains(string(content), "git drs precommit") { + t.Fatalf("expected hook to contain git drs precommit") + } + + if err := installPreCommitHook(logger); err != nil { + t.Fatalf("installPreCommitHook second call error: %v", err) + } +} + func TestInitGitConfig(t *testing.T) { testutils.SetupTestGitRepo(t) transfers = 2 diff --git a/cmd/precommit/main.go b/cmd/precommit/main.go new file mode 100644 index 00000000..ac2f3047 --- /dev/null +++ b/cmd/precommit/main.go @@ -0,0 +1,547 @@ +// Package precommit +// ------------------------------------- +// LFS-only local cache updater for: +// - Path -> OID : .git/drs/pre-commit/v1/paths/.json +// - OID -> Paths + S3 URL hint : .git/drs/pre-commit/v1/oids/.json +// +// This hook is intentionally: +// - LFS-only (non-LFS paths are ignored) +// - local-only (no network, no server index reads) +// - index-based (reads STAGED content via `git show :`) +// +// Note: This is a reference implementation. Adjust logging/policy as desired. +package precommit + +import ( + "bufio" + "bytes" + "context" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "sort" + "strings" + "time" + + "github.com/spf13/cobra" +) + +const ( + cacheVersionDir = "drs/pre-commit/v1" + lfsSpecLine = "version https://git-lfs.github.com/spec/v1" +) + +type PathEntry struct { + Path string `json:"path"` + LFSOID string `json:"lfs_oid"` + UpdatedAt string `json:"updated_at"` +} + +type OIDEntry struct { + LFSOID string `json:"lfs_oid"` + Paths []string `json:"paths"` + S3URL string `json:"s3_url,omitempty"` // hint only; may be empty + UpdatedAt string `json:"updated_at"` + ContentChange bool `json:"content_changed"` +} + +type ChangeKind int + +const ( + KindAdd ChangeKind = iota + KindModify + KindDelete + KindRename +) + +type Change struct { + Kind ChangeKind + OldPath string // for rename + NewPath string // for rename (and for add/modify/delete uses NewPath) + Status string // raw status, e.g. 
"A", "M", "D", "R100" +} + +// Cmd line declaration +var Cmd = &cobra.Command{ + Use: "precommit", + Short: "pre-commit hook to update local DRS cache", + Long: "Pre-commit hook that updates the local DRS pre-commit cache", + Args: cobra.ExactArgs(0), + RunE: func(cmd *cobra.Command, args []string) error { + return run(context.Background()) + }, +} + +func main() { + ctx := context.Background() + if err := run(ctx); err != nil { + // For a reference impl, treat errors as non-fatal unless you want strict enforcement. + // Exiting non-zero blocks the commit. + fmt.Fprintf(os.Stderr, "pre-commit drs cache: %v\n", err) + os.Exit(1) + } +} + +func run(ctx context.Context) error { + gitDir, err := gitRevParseGitDir(ctx) + if err != nil { + return err + } + + cacheRoot := filepath.Join(gitDir, cacheVersionDir) + pathsDir := filepath.Join(cacheRoot, "paths") + oidsDir := filepath.Join(cacheRoot, "oids") + tombsDir := filepath.Join(cacheRoot, "tombstones") + + if err := os.MkdirAll(pathsDir, 0o755); err != nil { + return err + } + if err := os.MkdirAll(oidsDir, 0o755); err != nil { + return err + } + _ = os.MkdirAll(tombsDir, 0o755) // optional + + changes, err := stagedChanges(ctx) + if err != nil { + return err + } + if len(changes) == 0 { + return nil + } + + now := time.Now().UTC().Format(time.RFC3339) + + // Process renames first so subsequent add/modify logic sees the "new" path. + // This mirrors how we want cache paths to follow staged paths. + for _, ch := range changes { + if ch.Kind != KindRename { + continue + } + // Only act if BOTH old and new are LFS in scope? Prefer: + // - If the new path is LFS, we migrate. + // - If it isn't LFS, we remove old path entry (out of scope). + newOID, newIsLFS, err := stagedLFSOID(ctx, ch.NewPath) + if err != nil { + // If file doesn't exist in index due to weird staging, skip. + continue + } + + oldPathFile := pathEntryFile(pathsDir, ch.OldPath) + newPathFile := pathEntryFile(pathsDir, ch.NewPath) + + if newIsLFS { + // Move/overwrite path entry file + if err := moveFileBestEffort(oldPathFile, newPathFile); err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("rename migrate path entry: %w", err) + } + + // Ensure path entry content correct + if err := writeJSONAtomic(newPathFile, PathEntry{ + Path: ch.NewPath, + LFSOID: newOID, + UpdatedAt: now, + }); err != nil { + return err + } + + // Update oid entry: replace old path with new path for that OID + if err := oidAddOrReplacePath(oidsDir, newOID, ch.OldPath, ch.NewPath, now, false); err != nil { + return err + } + } else { + // Out of scope now: remove any cached path entry. + _ = os.Remove(oldPathFile) + } + } + + // Process adds/modifies/deletes (and renames again just to ensure content correctness on new path). + for _, ch := range changes { + switch ch.Kind { + case KindAdd, KindModify: + if err := handleUpsert(ctx, pathsDir, oidsDir, ch.NewPath, now); err != nil { + return err + } + case KindRename: + // Treat like upsert on NewPath to ensure OID/path consistency if content also changed. + if err := handleUpsert(ctx, pathsDir, oidsDir, ch.NewPath, now); err != nil { + return err + } + // Optionally also remove old path from *other* OID entry if rename+content-change changed OID. + // We'll do it inside handleUpsert by checking previous cached OID for that path (after move). 
+ case KindDelete: + if err := handleDelete(ctx, pathsDir, oidsDir, tombsDir, ch.NewPath, now); err != nil { + return err + } + } + } + + return nil +} + +func handleUpsert(ctx context.Context, pathsDir, oidsDir, path, now string) error { + oid, isLFS, err := stagedLFSOID(ctx, path) + if err != nil { + // If file isn't in index, ignore. + return nil + } + if !isLFS { + // Out of scope. + return nil + } + + pathFile := pathEntryFile(pathsDir, path) + + // Load previous path entry if it exists to detect content changes. + var prev PathEntry + prevExists := false + if b, err := os.ReadFile(pathFile); err == nil { + _ = json.Unmarshal(b, &prev) + if prev.Path != "" && prev.LFSOID != "" { + prevExists = true + } + } + + // Write/update path entry. + if err := writeJSONAtomic(pathFile, PathEntry{ + Path: path, + LFSOID: oid, + UpdatedAt: now, + }); err != nil { + return err + } + + // Update OID entry for new oid: add path. + contentChanged := prevExists && prev.LFSOID != oid + if err := oidAddOrReplacePath(oidsDir, oid, "", path, now, contentChanged); err != nil { + return err + } + + // If content changed, remove path from the *old* oid entry (best effort). + if contentChanged { + _ = oidRemovePath(oidsDir, prev.LFSOID, path, now) + } + + return nil +} + +func handleDelete(ctx context.Context, pathsDir, oidsDir, tombsDir, path, now string) error { + // Only consider deletion if it was previously an LFS entry (cache-driven). + pathFile := pathEntryFile(pathsDir, path) + b, err := os.ReadFile(pathFile) + if err != nil { + // nothing to do + return nil + } + var pe PathEntry + if err := json.Unmarshal(b, &pe); err != nil { + // corrupted cache; remove it + _ = os.Remove(pathFile) + return nil + } + // Remove path entry. + _ = os.Remove(pathFile) + + // Remove this path from the old oid entry (best effort). + if pe.LFSOID != "" { + _ = oidRemovePath(oidsDir, pe.LFSOID, path, now) + } + + // Optional tombstone. + tombFile := filepath.Join(tombsDir, encodePath(path)+".json") + _ = writeJSONAtomic(tombFile, map[string]string{ + "path": path, + "deleted_at": now, + }) + + return nil +} + +// stagedChanges parses: git diff --cached --name-status -M +// Formats: +// +// Apath +// Mpath +// Dpath +// R100oldnew +func stagedChanges(ctx context.Context) ([]Change, error) { + out, err := git(ctx, "diff", "--cached", "--name-status", "-M") + if err != nil { + return nil, err + } + var changes []Change + sc := bufio.NewScanner(bytes.NewReader(out)) + for sc.Scan() { + line := sc.Text() + if strings.TrimSpace(line) == "" { + continue + } + parts := strings.Split(line, "\t") + if len(parts) < 2 { + continue + } + status := parts[0] + switch { + case status == "A": + changes = append(changes, Change{Kind: KindAdd, NewPath: parts[1], Status: status}) + case status == "M": + changes = append(changes, Change{Kind: KindModify, NewPath: parts[1], Status: status}) + case status == "D": + changes = append(changes, Change{Kind: KindDelete, NewPath: parts[1], Status: status}) + case strings.HasPrefix(status, "R") && len(parts) >= 3: + changes = append(changes, Change{Kind: KindRename, OldPath: parts[1], NewPath: parts[2], Status: status}) + default: + // ignore other statuses (C, T, U, etc) for this reference impl + } + } + if err := sc.Err(); err != nil { + return nil, err + } + return changes, nil +} + +// stagedLFSOID returns (oid, isLFS, err) based on STAGED content. +// isLFS is true only if the staged file is a valid LFS pointer with an oid sha256 line. 
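+//
+// A staged pointer that qualifies looks like the following (the spec line must
+// match exactly; the oid is 64 hex characters):
+//
+//    version https://git-lfs.github.com/spec/v1
+//    oid sha256:<64 hex characters>
+//    size 12345
+//
+// in which case this function returns ("sha256:<64 hex characters>", true, nil).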
+func stagedLFSOID(ctx context.Context, path string) (string, bool, error) { + out, err := git(ctx, "show", ":"+path) + if err != nil { + // path may not exist in index (deleted/intent-to-add weirdness) + return "", false, err + } + + // Fast parse: look for spec line and oid line near top. + // LFS pointer files are small; scanning full content is fine. + var hasSpec bool + var oid string + + sc := bufio.NewScanner(bytes.NewReader(out)) + for sc.Scan() { + line := sc.Text() + if line == lfsSpecLine { + hasSpec = true + continue + } + if strings.HasPrefix(line, "oid sha256:") { + hex := strings.TrimPrefix(line, "oid sha256:") + hex = strings.TrimSpace(hex) + if hex != "" { + oid = "sha256:" + hex + } + // keep scanning a bit in case spec is below (rare), but we can break once both are found. + } + // pointer usually has only a few lines; stop early after 10 lines + if hasSpec && oid != "" { + break + } + } + if err := sc.Err(); err != nil { + return "", false, err + } + + if hasSpec && oid != "" { + return oid, true, nil + } + return "", false, nil +} + +func gitRevParseGitDir(ctx context.Context) (string, error) { + out, err := git(ctx, "rev-parse", "--git-dir") + if err != nil { + return "", err + } + gitDir := strings.TrimSpace(string(out)) + if gitDir == "" { + return "", errors.New("could not determine .git dir") + } + // If gitDir is relative, resolve relative to repo root + if !filepath.IsAbs(gitDir) { + rootOut, err := git(ctx, "rev-parse", "--show-toplevel") + if err != nil { + return "", err + } + root := strings.TrimSpace(string(rootOut)) + gitDir = filepath.Join(root, gitDir) + } + return gitDir, nil +} + +func git(ctx context.Context, args ...string) ([]byte, error) { + cmd := exec.CommandContext(ctx, "git", args...) + cmd.Env = os.Environ() + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + // include stderr for debugging; don’t leak massive output + msg := strings.TrimSpace(stderr.String()) + if msg == "" { + msg = err.Error() + } + return nil, fmt.Errorf("git %s: %s", strings.Join(args, " "), msg) + } + return stdout.Bytes(), nil +} + +// pathEntryFile maps a repo-relative path to a cache file location. +// We keep a deterministic encoding so any path maps to exactly one file. +func pathEntryFile(pathsDir, path string) string { + return filepath.Join(pathsDir, encodePath(path)+".json") +} + +func encodePath(path string) string { + // base64url encoding of the UTF-8 path string (no padding) is simple and safe. + return base64.RawURLEncoding.EncodeToString([]byte(path)) +} + +func oidEntryFile(oidsDir, oid string) string { + // OID contains ":"; make it filesystem safe but still human readable. + // Use a stable transform; here: sha256 of oid string to avoid path length issues. 
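+ // For example, the entry for "sha256:<hex>" is stored at
+ // <oidsDir>/<hex of sha256 of the full oid string>.json, so the same OID
+ // always maps to the same file no matter which worktree paths reference it.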
+ sum := sha256.Sum256([]byte(oid)) + return filepath.Join(oidsDir, fmt.Sprintf("%x.json", sum[:])) +} + +// oidAddOrReplacePath: +// - loads oid entry (if exists) +// - adds newPath to paths[] +// - if oldPath != "" and present, replaces it with newPath +// - sets ContentChange flag if requested (ORed into existing flag) +// - preserves existing s3_url hint +func oidAddOrReplacePath(oidsDir, oid, oldPath, newPath, now string, contentChanged bool) error { + f := oidEntryFile(oidsDir, oid) + + entry := OIDEntry{ + LFSOID: oid, + Paths: []string{}, + UpdatedAt: now, + } + if b, err := os.ReadFile(f); err == nil { + _ = json.Unmarshal(b, &entry) + // ensure oid is set even if old file was incomplete + entry.LFSOID = oid + } + + paths := make(map[string]struct{}, len(entry.Paths)+1) + for _, p := range entry.Paths { + paths[p] = struct{}{} + } + + if oldPath != "" { + delete(paths, oldPath) + } + if newPath != "" { + paths[newPath] = struct{}{} + } + + entry.Paths = keysSorted(paths) + entry.UpdatedAt = now + entry.ContentChange = entry.ContentChange || contentChanged + + return writeJSONAtomic(f, entry) +} + +func oidRemovePath(oidsDir, oid, path, now string) error { + f := oidEntryFile(oidsDir, oid) + + b, err := os.ReadFile(f) + if err != nil { + return err + } + var entry OIDEntry + if err := json.Unmarshal(b, &entry); err != nil { + return err + } + paths := make(map[string]struct{}, len(entry.Paths)) + for _, p := range entry.Paths { + if p == path { + continue + } + paths[p] = struct{}{} + } + entry.Paths = keysSorted(paths) + entry.UpdatedAt = now + + // If no paths remain, keep the file (it may still hold s3_url hint) or delete it. + // This ADR allows stale entries; keeping is fine. Optionally delete when empty: + // if len(entry.Paths) == 0 && entry.S3URL == "" { return os.Remove(f) } + + return writeJSONAtomic(f, entry) +} + +func keysSorted(m map[string]struct{}) []string { + out := make([]string, 0, len(m)) + for k := range m { + out = append(out, k) + } + sort.Strings(out) + return out +} + +// writeJSONAtomic writes JSON to a temp file then renames it into place. +// This avoids partially written cache files if the process is interrupted. +func writeJSONAtomic(path string, v any) error { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + + tmp := path + ".tmp" + f, err := os.OpenFile(tmp, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644) + if err != nil { + return err + } + defer func() { _ = f.Close() }() + + enc := json.NewEncoder(f) + enc.SetIndent("", " ") + if err := enc.Encode(v); err != nil { + _ = os.Remove(tmp) + return err + } + if err := f.Sync(); err != nil { + _ = os.Remove(tmp) + return err + } + if err := f.Close(); err != nil { + _ = os.Remove(tmp) + return err + } + return os.Rename(tmp, path) +} + +func moveFileBestEffort(src, dst string) error { + // Ensure destination directory exists. + if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil { + return err + } + // Rename will fail across devices; fall back to copy+remove. 
+ if err := os.Rename(src, dst); err == nil { + return nil + } else if errors.Is(err, os.ErrNotExist) { + return err + } + + in, err := os.Open(src) + if err != nil { + return err + } + defer in.Close() + + out, err := os.OpenFile(dst, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644) + if err != nil { + return err + } + + if _, err := io.Copy(out, in); err != nil { + _ = out.Close() + return err + } + if err := out.Close(); err != nil { + return err + } + return os.Remove(src) +} diff --git a/cmd/precommit/main_test.go b/cmd/precommit/main_test.go new file mode 100644 index 00000000..8a0fb0c6 --- /dev/null +++ b/cmd/precommit/main_test.go @@ -0,0 +1,146 @@ +package precommit + +import ( + "context" + "encoding/json" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestHandleUpsertIgnoresNonLFSFile(t *testing.T) { + repo := setupGitRepo(t) + oldwd := mustChdir(t, repo) + t.Cleanup(func() { _ = os.Chdir(oldwd) }) + + path := filepath.Join(repo, "data", "file.txt") + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(path, []byte("plain content"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + gitCmd(t, repo, "add", "data/file.txt") + + cacheRoot := filepath.Join(repo, ".git", "drs", "pre-commit", "v1") + pathsDir := filepath.Join(cacheRoot, "paths") + oidsDir := filepath.Join(cacheRoot, "oids") + if err := os.MkdirAll(pathsDir, 0o755); err != nil { + t.Fatalf("mkdir paths: %v", err) + } + if err := os.MkdirAll(oidsDir, 0o755); err != nil { + t.Fatalf("mkdir oids: %v", err) + } + + now := time.Now().UTC().Format(time.RFC3339) + if err := handleUpsert(context.Background(), pathsDir, oidsDir, "data/file.txt", now); err != nil { + t.Fatalf("handleUpsert: %v", err) + } + + pathEntry := pathEntryFile(pathsDir, "data/file.txt") + if _, err := os.Stat(pathEntry); !os.IsNotExist(err) { + t.Fatalf("expected no cache entry for non-LFS file, got err=%v", err) + } +} + +func TestHandleUpsertWritesLFSPointerCache(t *testing.T) { + repo := setupGitRepo(t) + oldwd := mustChdir(t, repo) + t.Cleanup(func() { _ = os.Chdir(oldwd) }) + + path := filepath.Join(repo, "data", "file.bin") + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + lfsPointer := strings.Join([]string{ + "version https://git-lfs.github.com/spec/v1", + "oid sha256:deadbeef", + "size 12", + "", + }, "\n") + if err := os.WriteFile(path, []byte(lfsPointer), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + gitCmd(t, repo, "add", "data/file.bin") + + cacheRoot := filepath.Join(repo, ".git", "drs", "pre-commit", "v1") + pathsDir := filepath.Join(cacheRoot, "paths") + oidsDir := filepath.Join(cacheRoot, "oids") + if err := os.MkdirAll(pathsDir, 0o755); err != nil { + t.Fatalf("mkdir paths: %v", err) + } + if err := os.MkdirAll(oidsDir, 0o755); err != nil { + t.Fatalf("mkdir oids: %v", err) + } + + now := time.Now().UTC().Format(time.RFC3339) + if err := handleUpsert(context.Background(), pathsDir, oidsDir, "data/file.bin", now); err != nil { + t.Fatalf("handleUpsert: %v", err) + } + + pathEntry := pathEntryFile(pathsDir, "data/file.bin") + pathData, err := os.ReadFile(pathEntry) + if err != nil { + t.Fatalf("read path entry: %v", err) + } + var pathCache PathEntry + if err := json.Unmarshal(pathData, &pathCache); err != nil { + t.Fatalf("unmarshal path entry: %v", err) + } + if pathCache.Path != "data/file.bin" { + t.Fatalf("expected path entry to be data/file.bin, got %q", pathCache.Path) + } 
+ if pathCache.LFSOID != "sha256:deadbeef" { + t.Fatalf("expected lfs oid sha256:deadbeef, got %q", pathCache.LFSOID) + } + + oidEntry := oidEntryFile(oidsDir, "sha256:deadbeef") + oidData, err := os.ReadFile(oidEntry) + if err != nil { + t.Fatalf("read oid entry: %v", err) + } + var oidCache OIDEntry + if err := json.Unmarshal(oidData, &oidCache); err != nil { + t.Fatalf("unmarshal oid entry: %v", err) + } + if oidCache.LFSOID != "sha256:deadbeef" { + t.Fatalf("expected oid entry sha256:deadbeef, got %q", oidCache.LFSOID) + } + if len(oidCache.Paths) != 1 || oidCache.Paths[0] != "data/file.bin" { + t.Fatalf("expected oid paths to include data/file.bin, got %v", oidCache.Paths) + } +} + +func setupGitRepo(t *testing.T) string { + t.Helper() + dir := t.TempDir() + gitCmd(t, dir, "init") + gitCmd(t, dir, "config", "user.email", "test@example.com") + gitCmd(t, dir, "config", "user.name", "Test User") + return dir +} + +func mustChdir(t *testing.T, dir string) string { + t.Helper() + old, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd: %v", err) + } + if err := os.Chdir(dir); err != nil { + t.Fatalf("Chdir(%s): %v", dir, err) + } + return old +} + +func gitCmd(t *testing.T, dir string, args ...string) { + t.Helper() + cmd := exec.Command("git", args...) + cmd.Dir = dir + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("git %s failed: %v (%s)", strings.Join(args, " "), err, string(out)) + } +} diff --git a/cmd/prepush/main.go b/cmd/prepush/main.go index 235377c8..751adbb0 100644 --- a/cmd/prepush/main.go +++ b/cmd/prepush/main.go @@ -2,16 +2,22 @@ package prepush import ( "bufio" + "context" "fmt" "io" + "log/slog" "os" + "os/exec" "sort" "strings" + "time" - indexd_client "github.com/calypr/git-drs/client/indexd" "github.com/calypr/git-drs/config" + "github.com/calypr/git-drs/drs" "github.com/calypr/git-drs/drslog" "github.com/calypr/git-drs/drsmap" + drslfs "github.com/calypr/git-drs/drsmap/lfs" + "github.com/calypr/git-drs/precommit_cache" "github.com/spf13/cobra" ) @@ -22,108 +28,135 @@ var Cmd = &cobra.Command{ Long: "Pre-push hook that updates DRS objects before transfer", Args: cobra.RangeArgs(0, 2), RunE: func(cmd *cobra.Command, args []string) error { - //myLogger := drslog.GetLogger() - myLogger, err := drslog.NewLogger("", false) - if err != nil { - return fmt.Errorf("error creating logger: %v", err) - } + return NewPrePushService().Run(args, os.Stdin) + }, +} - myLogger.Debug("~~~~~~~~~~~~~ START: pre-push ~~~~~~~~~~~~~") +type PrePushService struct { + newLogger func(string, bool) (*slog.Logger, error) + loadConfig func() (*config.Config, error) + updateDrsObjects func(drs.ObjectBuilder, map[string]drslfs.LfsFileInfo, *precommit_cache.Cache, bool, *slog.Logger) error + createTempFile func(dir, pattern string) (*os.File, error) +} - cfg, err := config.LoadConfig() - if err != nil { - return fmt.Errorf("error getting config: %v", err) - } +func NewPrePushService() *PrePushService { + return &PrePushService{ + newLogger: drslog.NewLogger, + loadConfig: config.LoadConfig, + updateDrsObjects: drsmap.UpdateDrsObjectsWithFiles, + createTempFile: os.CreateTemp, + } +} - //Command-line arguments: The hook receives two parameters: - //* The name of the remote (e.g., origin). - //* The remote's location/URL (e.g., github.com). - // Create gitRemoteName and gitRemoteLocation from args. 
- myLogger.Debug(fmt.Sprintf("pre-push args: %v", args)) - var gitRemoteName, gitRemoteLocation string - if len(args) >= 1 { - gitRemoteName = args[0] - } - if len(args) >= 2 { - gitRemoteLocation = args[1] - } - if gitRemoteName == "" { - gitRemoteName = "origin" - } - myLogger.Debug(fmt.Sprintf("git remote name: %s, git remote location: %s", gitRemoteName, gitRemoteLocation)) +func (s *PrePushService) Run(args []string, stdin io.Reader) error { + ctx := context.Background() + myLogger, err := s.newLogger("", false) + if err != nil { + return fmt.Errorf("error creating logger: %v", err) + } - // get the default remote from the .git/drs/config - var remote config.Remote - remote, err = cfg.GetDefaultRemote() - if err != nil { - myLogger.Debug(fmt.Sprintf("Warning. Error getting default remote: %v", err)) - // Print warning to stderr and return success (exit 0) - fmt.Fprintln(os.Stderr, "Warning. Skipping DRS preparation. Error getting default remote:", err) - return nil - } + myLogger.Info("~~~~~~~~~~~~~ START: pre-push ~~~~~~~~~~~~~") - // get the remote client - cli, err := cfg.GetRemoteClient(remote, myLogger) - if err != nil { - // Print warning to stderr and return success (exit 0) - fmt.Fprintln(os.Stderr, "Warning. Skipping DRS preparation. Error getting remote client:", err) - myLogger.Debug(fmt.Sprintf("Warning. Skipping DRS preparation. Error getting remote client: %v", err)) - return nil - } + cfg, err := s.loadConfig() + if err != nil { + return fmt.Errorf("error getting config: %v", err) + } - dc, ok := cli.(*indexd_client.IndexDClient) - if !ok { - return fmt.Errorf("cli is not IndexdClient: %T", cli) - } - myLogger.Debug(fmt.Sprintf("Current server: %s", dc.ProjectId)) + gitRemoteName, gitRemoteLocation := parseRemoteArgs(args) + myLogger.Debug(fmt.Sprintf("git remote name: %s, git remote location: %s", gitRemoteName, gitRemoteLocation)) - // Buffer stdin to a temp file and invoke `git lfs pre-push ` with same args and stdin. - tmp, err := os.CreateTemp("", "prepush-stdin-*") - if err != nil { - myLogger.Debug(fmt.Sprintf("error creating temp file for stdin: %v", err)) - return err - } - defer func() { - _ = tmp.Close() - _ = os.Remove(tmp.Name()) - }() - - // Copy all of stdin into the temp file. - if _, err := io.Copy(tmp, os.Stdin); err != nil { - myLogger.Debug(fmt.Sprintf("error buffering stdin: %v", err)) - return err - } + remote, err := cfg.GetDefaultRemote() + if err != nil { + myLogger.Debug(fmt.Sprintf("Warning. Error getting default remote: %v", err)) + fmt.Fprintln(os.Stderr, "Warning. Skipping DRS preparation. Error getting default remote:", err) + return nil + } - // Rewind to start so the child process can read it. - if _, err := tmp.Seek(0, 0); err != nil { - myLogger.Debug(fmt.Sprintf("error seeking temp stdin: %v", err)) - return err - } + remoteConfig := cfg.GetRemote(remote) + if remoteConfig == nil { + fmt.Fprintln(os.Stderr, "Warning. Skipping DRS preparation. Error getting remote configuration.") + myLogger.Debug("Warning. Skipping DRS preparation. 
Error getting remote configuration.") + return nil + } - // read the temp file and get a list of all unique local branches being pushed - branches, err := readPushedBranches(tmp) - if err != nil { - myLogger.Debug(fmt.Sprintf("error reading pushed branches: %v", err)) - return err - } + builder := drs.NewObjectBuilder(remoteConfig.GetBucketName(), remoteConfig.GetProjectId()) + myLogger.Debug(fmt.Sprintf("Current server project: %s", builder.ProjectID)) - myLogger.Debug(fmt.Sprintf("Preparing DRS objects for push branches: %v", branches)) - err = drsmap.UpdateDrsObjects(cli, gitRemoteName, gitRemoteLocation, branches, myLogger) - if err != nil { - myLogger.Debug(fmt.Sprintf("UpdateDrsObjects failed: %v", err)) - return err - } - myLogger.Debug("DRS objects prepared for push!") + tmp, err := bufferStdin(stdin, s.createTempFile) + if err != nil { + myLogger.Error(fmt.Sprintf("error buffering stdin: %v", err)) + return err + } + defer func() { + _ = tmp.Close() + _ = os.Remove(tmp.Name()) + }() - myLogger.Debug("~~~~~~~~~~~~~ COMPLETED: pre-push ~~~~~~~~~~~~~") - return nil - }, + refs, err := readPushedRefs(tmp) + if err != nil { + myLogger.Error(fmt.Sprintf("error reading pushed refs: %v", err)) + return err + } + branches := branchesFromRefs(refs) + + cache, cacheReady := openCache(ctx, myLogger) + lfsFiles, usedCache, err := collectLfsFiles(ctx, cache, cacheReady, gitRemoteName, gitRemoteLocation, branches, refs, myLogger) + if err != nil { + myLogger.Error(fmt.Sprintf("error collecting LFS files: %v", err)) + return err + } + + myLogger.Debug(fmt.Sprintf("Preparing DRS objects for push branches: %v (cache=%v)", branches, usedCache)) + err = s.updateDrsObjects(builder, lfsFiles, cache, usedCache, myLogger) + if err != nil { + myLogger.Error(fmt.Sprintf("UpdateDrsObjects failed: %v", err)) + return err + } + myLogger.Info("~~~~~~~~~~~~~ COMPLETED: pre-push ~~~~~~~~~~~~~") + return nil +} + +func parseRemoteArgs(args []string) (string, string) { + var gitRemoteName, gitRemoteLocation string + if len(args) >= 1 { + gitRemoteName = args[0] + } + if len(args) >= 2 { + gitRemoteLocation = args[1] + } + if gitRemoteName == "" { + gitRemoteName = "origin" + } + return gitRemoteName, gitRemoteLocation +} + +type pushedRef struct { + LocalRef string + LocalSHA string + RemoteRef string + RemoteSHA string +} + +func bufferStdin(stdin io.Reader, createTempFile func(dir, pattern string) (*os.File, error)) (*os.File, error) { + tmp, err := createTempFile("", "prepush-stdin-*") + if err != nil { + return nil, fmt.Errorf("error creating temp file for stdin: %w", err) + } + + if _, err := io.Copy(tmp, stdin); err != nil { + return nil, fmt.Errorf("error buffering stdin: %w", err) + } + + if _, err := tmp.Seek(0, 0); err != nil { + return nil, fmt.Errorf("error seeking temp stdin: %w", err) + } + return tmp, nil } // readPushedBranches reads git push lines from the provided temp file, // extracts unique local branch names for refs under `refs/heads/` and // returns them sorted. The file is rewound to the start before returning. 
-func readPushedBranches(f *os.File) ([]string, error) {
+func readPushedRefs(f io.ReadSeeker) ([]pushedRef, error) {
 	// Ensure we read from start
 	// example:
 	// refs/heads/main 67890abcdef1234567890abcdef1234567890abcd refs/heads/main 12345abcdef67890abcdef1234567890abcdef12
@@ -131,33 +164,160 @@ func readPushedBranches(f *os.File) ([]string, error) {
 		return nil, err
 	}
 	scanner := bufio.NewScanner(f)
-	set := make(map[string]struct{})
+	refs := make([]pushedRef, 0)
 	for scanner.Scan() {
 		line := scanner.Text()
 		fields := strings.Fields(line)
-		if len(fields) < 1 {
+		if len(fields) < 4 {
 			continue
 		}
-		localRef := fields[0]
-		const prefix = "refs/heads/"
-		if strings.HasPrefix(localRef, prefix) {
-			branch := strings.TrimPrefix(localRef, prefix)
+		refs = append(refs, pushedRef{
+			LocalRef:  fields[0],
+			LocalSHA:  fields[1],
+			RemoteRef: fields[2],
+			RemoteSHA: fields[3],
+		})
+	}
+	if err := scanner.Err(); err != nil {
+		return nil, err
+	}
+	// Rewind so caller can reuse the file
+	if _, err := f.Seek(0, 0); err != nil {
+		return nil, err
+	}
+	return refs, nil
+}
+
+func branchesFromRefs(refs []pushedRef) []string {
+	const prefix = "refs/heads/"
+	set := make(map[string]struct{})
+	for _, ref := range refs {
+		if strings.HasPrefix(ref.LocalRef, prefix) {
+			branch := strings.TrimPrefix(ref.LocalRef, prefix)
 			if branch != "" {
 				set[branch] = struct{}{}
 			}
 		}
 	}
-	if err := scanner.Err(); err != nil {
-		return nil, err
-	}
 	branches := make([]string, 0, len(set))
 	for b := range set {
 		branches = append(branches, b)
 	}
 	sort.Strings(branches)
-	// Rewind so caller can reuse the file
-	if _, err := f.Seek(0, 0); err != nil {
-		return nil, err
+	return branches
+}
+
+func openCache(ctx context.Context, logger *slog.Logger) (*precommit_cache.Cache, bool) {
+	cache, err := precommit_cache.Open(ctx)
+	if err != nil {
+		logger.Debug(fmt.Sprintf("pre-commit cache unavailable: %v", err))
+		return nil, false
+	}
+	if _, err := os.Stat(cache.Root); err != nil {
+		if os.IsNotExist(err) {
+			logger.Debug("pre-commit cache missing; continuing without cache")
+		} else {
+			logger.Debug(fmt.Sprintf("pre-commit cache access error: %v", err))
+		}
+		return nil, false
+	}
+	return cache, true
+}
+
+func collectLfsFiles(ctx context.Context, cache *precommit_cache.Cache, cacheReady bool, gitRemoteName, gitRemoteLocation string, branches []string, refs []pushedRef, logger *slog.Logger) (map[string]drslfs.LfsFileInfo, bool, error) {
+	if cacheReady {
+		lfsFiles, ok, err := lfsFilesFromCache(ctx, cache, refs, logger)
+		if err != nil {
+			logger.Debug(fmt.Sprintf("pre-commit cache read failed: %v", err))
+		} else if ok {
+			return lfsFiles, true, nil
+		}
+		logger.Debug("pre-commit cache incomplete or stale; falling back to LFS discovery")
+	}
+	lfsFiles, err := drslfs.GetAllLfsFiles(gitRemoteName, gitRemoteLocation, branches, logger)
+	if err != nil {
+		return nil, false, err
+	}
+	return lfsFiles, false, nil
+}
+
+const cacheMaxAge = 24 * time.Hour
+
+func lfsFilesFromCache(ctx context.Context, cache *precommit_cache.Cache, refs []pushedRef, logger *slog.Logger) (map[string]drslfs.LfsFileInfo, bool, error) {
+	if cache == nil {
+		return nil, false, nil
+	}
+	paths, err := listPushedPaths(ctx, refs)
+	if err != nil {
+		return nil, false, err
+	}
+	lfsFiles := make(map[string]drslfs.LfsFileInfo, len(paths))
+	for _, path := range paths {
+		entry, ok, err := cache.ReadPathEntry(path)
+		if err != nil {
+			return nil, false, err
+		}
+		if !ok || entry.LFSOID == "" {
+			return nil, false, nil
+		}
+		if entry.UpdatedAt == "" || precommit_cache.StaleAfter(entry.UpdatedAt, cacheMaxAge) {
+			return nil, false, nil
+		}
+		stat, err := os.Stat(path)
+		if err != nil {
+			logger.Debug(fmt.Sprintf("cache path stat failed for %s: %v", path, err))
+			return nil, false, nil
+		}
+		lfsFiles[path] = drslfs.LfsFileInfo{
+			Name:    path,
+			Size:    stat.Size(),
+			OidType: "sha256",
+			Oid:     entry.LFSOID,
+			Version: "https://git-lfs.github.com/spec/v1",
+		}
+	}
+	return lfsFiles, true, nil
+}
+
+func listPushedPaths(ctx context.Context, refs []pushedRef) ([]string, error) {
+	const zeroSHA = "0000000000000000000000000000000000000000"
+	set := make(map[string]struct{})
+	for _, ref := range refs {
+		if ref.LocalSHA == "" || ref.LocalSHA == zeroSHA {
+			continue
+		}
+		var args []string
+		if ref.RemoteSHA == "" || ref.RemoteSHA == zeroSHA {
+			args = []string{"ls-tree", "-r", "--name-only", ref.LocalSHA}
+		} else {
+			args = []string{"diff", "--name-only", ref.RemoteSHA, ref.LocalSHA}
+		}
+		out, err := gitOutput(ctx, args...)
+		if err != nil {
+			return nil, err
+		}
+		for _, line := range strings.Split(strings.TrimSpace(out), "\n") {
+			line = strings.TrimSpace(line)
+			if line == "" {
+				continue
+			}
+			set[line] = struct{}{}
+		}
+	}
+	paths := make([]string, 0, len(set))
+	for path := range set {
+		paths = append(paths, path)
+	}
+	sort.Strings(paths)
+	return paths, nil
+}
+
+func gitOutput(ctx context.Context, args ...string) (string, error) {
+	cmd := exec.CommandContext(ctx, "git", args...)
+	cmd.Env = os.Environ()
+	out, err := cmd.CombinedOutput()
+	if err != nil {
+		return "", fmt.Errorf("git %s: %s", strings.Join(args, " "), strings.TrimSpace(string(out)))
 	}
-	return branches, nil
+	return string(out), nil
 }
diff --git a/cmd/prepush/main_test.go b/cmd/prepush/main_test.go
new file mode 100644
index 00000000..1af26146
--- /dev/null
+++ b/cmd/prepush/main_test.go
@@ -0,0 +1,205 @@
+package prepush
+
+import (
+	"context"
+	"encoding/json"
+	"io"
+	"log/slog"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/calypr/git-drs/precommit_cache"
+)
+
+func TestLfsFilesFromCache(t *testing.T) {
+	repo := setupGitRepo(t)
+	filePath := filepath.Join(repo, "data", "file.bin")
+	if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
+		t.Fatalf("mkdir: %v", err)
+	}
+	if err := os.WriteFile(filePath, []byte("first"), 0o644); err != nil {
+		t.Fatalf("write: %v", err)
+	}
+	gitCmd(t, repo, "add", "data/file.bin")
+	gitCmd(t, repo, "commit", "-m", "first")
+	oldSHA := gitOutputString(t, repo, "rev-parse", "HEAD")
+
+	if err := os.WriteFile(filePath, []byte("second"), 0o644); err != nil {
+		t.Fatalf("write: %v", err)
+	}
+	gitCmd(t, repo, "add", "data/file.bin")
+	gitCmd(t, repo, "commit", "-m", "second")
+	newSHA := gitOutputString(t, repo, "rev-parse", "HEAD")
+
+	cacheRoot := filepath.Join(repo, ".git", "drs", "pre-commit", "v1")
+	cache := &precommit_cache.Cache{
+		GitDir:    filepath.Join(repo, ".git"),
+		Root:      cacheRoot,
+		PathsDir:  filepath.Join(cacheRoot, "paths"),
+		OIDsDir:   filepath.Join(cacheRoot, "oids"),
+		StatePath: filepath.Join(cacheRoot, "state.json"),
+	}
+	if err := os.MkdirAll(cache.PathsDir, 0o755); err != nil {
+		t.Fatalf("mkdir paths dir: %v", err)
+	}
+	if err := os.MkdirAll(cache.OIDsDir, 0o755); err != nil {
+		t.Fatalf("mkdir oids dir: %v", err)
+	}
+
+	pathEntry := precommit_cache.PathEntry{
+		Path:      "data/file.bin",
+		LFSOID:    "oid-123",
+		UpdatedAt: time.Now().UTC().Format(time.RFC3339),
+	}
+	pathEntryFile := filepath.Join(cache.PathsDir, precommit_cache.EncodePath(pathEntry.Path)+".json")
+	writeJSON(t, pathEntryFile, pathEntry)
+
+	oldwd := mustChdir(t, repo)
+	t.Cleanup(func() { _ = os.Chdir(oldwd) })
+
+	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+	refs := []pushedRef{{
+		LocalRef:  "refs/heads/main",
+		LocalSHA:  newSHA,
+		RemoteRef: "refs/heads/main",
+		RemoteSHA: oldSHA,
+	}}
+
+	lfsFiles, ok, err := lfsFilesFromCache(context.Background(), cache, refs, logger)
+	if err != nil {
+		t.Fatalf("lfsFilesFromCache: %v", err)
+	}
+	if !ok {
+		t.Fatalf("expected cache to be usable")
+	}
+	info, exists := lfsFiles["data/file.bin"]
+	if !exists {
+		t.Fatalf("expected lfs info for data/file.bin")
+	}
+	if info.Oid != "oid-123" {
+		t.Fatalf("expected oid to be oid-123, got %s", info.Oid)
+	}
+	if info.OidType != "sha256" {
+		t.Fatalf("expected oid type sha256, got %s", info.OidType)
+	}
+	stat, err := os.Stat(filePath)
+	if err != nil {
+		t.Fatalf("stat: %v", err)
+	}
+	if info.Size != stat.Size() {
+		t.Fatalf("expected size %d, got %d", stat.Size(), info.Size)
+	}
+}
+
+func TestLfsFilesFromCacheStale(t *testing.T) {
+	repo := setupGitRepo(t)
+	filePath := filepath.Join(repo, "data", "file.bin")
+	if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
+		t.Fatalf("mkdir: %v", err)
+	}
+	if err := os.WriteFile(filePath, []byte("data"), 0o644); err != nil {
+		t.Fatalf("write: %v", err)
+	}
+	gitCmd(t, repo, "add", "data/file.bin")
+	gitCmd(t, repo, "commit", "-m", "first")
+	sha := gitOutputString(t, repo, "rev-parse", "HEAD")
+
+	cacheRoot := filepath.Join(repo, ".git", "drs", "pre-commit", "v1")
+	cache := &precommit_cache.Cache{
+		GitDir:    filepath.Join(repo, ".git"),
+		Root:      cacheRoot,
+		PathsDir:  filepath.Join(cacheRoot, "paths"),
+		OIDsDir:   filepath.Join(cacheRoot, "oids"),
+		StatePath: filepath.Join(cacheRoot, "state.json"),
+	}
+	if err := os.MkdirAll(cache.PathsDir, 0o755); err != nil {
+		t.Fatalf("mkdir paths dir: %v", err)
+	}
+
+	pathEntry := precommit_cache.PathEntry{
+		Path:      "data/file.bin",
+		LFSOID:    "oid-123",
+		UpdatedAt: time.Now().Add(-48 * time.Hour).UTC().Format(time.RFC3339),
+	}
+	pathEntryFile := filepath.Join(cache.PathsDir, precommit_cache.EncodePath(pathEntry.Path)+".json")
+	writeJSON(t, pathEntryFile, pathEntry)
+
+	oldwd := mustChdir(t, repo)
+	t.Cleanup(func() { _ = os.Chdir(oldwd) })
+
+	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+	refs := []pushedRef{{
+		LocalRef:  "refs/heads/main",
+		LocalSHA:  sha,
+		RemoteRef: "refs/heads/main",
+		RemoteSHA: "0000000000000000000000000000000000000000",
+	}}
+
+	_, ok, err := lfsFilesFromCache(context.Background(), cache, refs, logger)
+	if err != nil {
+		t.Fatalf("lfsFilesFromCache: %v", err)
+	}
+	if ok {
+		t.Fatalf("expected cache to be stale")
+	}
+}
+
+func setupGitRepo(t *testing.T) string {
+	t.Helper()
+	dir := t.TempDir()
+	gitCmd(t, dir, "init")
+	gitCmd(t, dir, "config", "user.email", "test@example.com")
+	gitCmd(t, dir, "config", "user.name", "Test User")
+	return dir
+}
+
+func gitCmd(t *testing.T, dir string, args ...string) {
+	t.Helper()
+	cmd := exec.Command("git", args...)
+	cmd.Dir = dir
+	out, err := cmd.CombinedOutput()
+	if err != nil {
+		t.Fatalf("git %s failed: %v (%s)", strings.Join(args, " "), err, string(out))
+	}
+}
+
+func gitOutputString(t *testing.T, dir string, args ...string) string {
+	t.Helper()
+	cmd := exec.Command("git", args...)
+	cmd.Dir = dir
+	out, err := cmd.CombinedOutput()
+	if err != nil {
+		t.Fatalf("git %s failed: %v (%s)", strings.Join(args, " "), err, string(out))
+	}
+	return strings.TrimSpace(string(out))
+}
+
+func writeJSON(t *testing.T, path string, value any) {
+	t.Helper()
+	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
+		t.Fatalf("mkdir: %v", err)
+	}
+	data, err := json.Marshal(value)
+	if err != nil {
+		t.Fatalf("marshal: %v", err)
+	}
+	if err := os.WriteFile(path, data, 0o644); err != nil {
+		t.Fatalf("write: %v", err)
+	}
+}
+
+func mustChdir(t *testing.T, dir string) string {
+	t.Helper()
+	old, err := os.Getwd()
+	if err != nil {
+		t.Fatalf("Getwd: %v", err)
+	}
+	if err := os.Chdir(dir); err != nil {
+		t.Fatalf("Chdir(%s): %v", dir, err)
+	}
+	return old
+}
diff --git a/cmd/root.go b/cmd/root.go
index 5bcc71b2..6f4792a8 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -11,6 +11,7 @@ import (
 	"github.com/calypr/git-drs/cmd/initialize"
 	"github.com/calypr/git-drs/cmd/list"
 	"github.com/calypr/git-drs/cmd/listconfig"
+	"github.com/calypr/git-drs/cmd/precommit"
 	"github.com/calypr/git-drs/cmd/prepush"
 	"github.com/calypr/git-drs/cmd/push"
 	"github.com/calypr/git-drs/cmd/query"
@@ -43,12 +44,13 @@ func init() {
 	RootCmd.AddCommand(list.Cmd)
 	RootCmd.AddCommand(list.ListProjectCmd)
 	RootCmd.AddCommand(listconfig.Cmd)
+	RootCmd.AddCommand(precommit.Cmd)
 	RootCmd.AddCommand(prepush.Cmd)
 	RootCmd.AddCommand(query.Cmd)
 	RootCmd.AddCommand(transfer.Cmd)
 	RootCmd.AddCommand(transferref.Cmd)
 	RootCmd.AddCommand(version.Cmd)
-	RootCmd.AddCommand(addurl.AddURLCmd)
+	RootCmd.AddCommand(addurl.Cmd)
 	RootCmd.AddCommand(remote.Cmd)
 	RootCmd.AddCommand(fetch.Cmd)
 	RootCmd.AddCommand(push.Cmd)
diff --git a/cmd/transfer/main.go b/cmd/transfer/main.go
index 917c10d2..a7dbf5b8 100644
--- a/cmd/transfer/main.go
+++ b/cmd/transfer/main.go
@@ -9,12 +9,12 @@ import (
 	"github.com/bytedance/sonic/encoder"
 	"github.com/calypr/data-client/client/common"
 	"github.com/calypr/git-drs/client"
+	"github.com/calypr/git-drs/cloud"
 	"github.com/calypr/git-drs/config"
 	"github.com/calypr/git-drs/drslog"
 	"github.com/calypr/git-drs/drsmap"
 	"github.com/calypr/git-drs/lfs"
 	"github.com/calypr/git-drs/projectdir"
-	"github.com/calypr/git-drs/s3_utils"
 	"github.com/spf13/cobra"
 )
@@ -147,7 +147,16 @@ var Cmd = &cobra.Command{
 			if err != nil {
 				errMsg := fmt.Sprintf("Error getting signed URL for OID %s: %v", downloadMsg.Oid, err)
 				logger.Error(errMsg)
-				lfs.WriteErrorMessage(streamEncoder, downloadMsg.Oid, 502, errMsg)
+
+				drsObject, errG := drsmap.DrsInfoFromOid(downloadMsg.Oid)
+				if errG == nil && drsObject != nil && len(drsObject.AccessMethods) > 0 {
+					manualDownloadMsg := fmt.Sprintf("%s %s", drsObject.AccessMethods[0].AccessURL.URL, drsObject.Name)
+					logger.Info(manualDownloadMsg)
+					lfs.WriteErrorMessage(streamEncoder, downloadMsg.Oid, 302, manualDownloadMsg)
+				} else {
+					logger.Error(fmt.Sprintf("drsmap.DrsInfoFromOid failed for %s: %v", downloadMsg.Oid, errG))
+					lfs.WriteErrorMessage(streamEncoder, downloadMsg.Oid, 502, errMsg)
+				}
 				continue
 			}
 			if accessUrl.URL == "" {
@@ -165,7 +174,7 @@
 				lfs.WriteErrorMessage(streamEncoder, downloadMsg.Oid, 400, errMsg)
 				continue
 			}
-			err = s3_utils.DownloadSignedUrl(accessUrl.URL, dstPath)
+			err = cloud.DownloadSignedUrl(accessUrl.URL, dstPath)
 			if err != nil {
 				errMsg := fmt.Sprintf("Error downloading file for OID %s: %v", downloadMsg.Oid, err)
 				logger.Error(errMsg)
diff --git a/config/config.go b/config/config.go
index c00edbbb..d7521651 100644
--- a/config/config.go
+++ b/config/config.go
@@ -68,7 +68,8 @@ type Config struct {
 func (c Config) GetRemoteClient(remote Remote, logger *slog.Logger) (client.DRSClient, error) {
 	x, ok := c.Remotes[remote]
 	if !ok {
-		return nil, fmt.Errorf("GetRemoteClient no remote configuration found for current remote: %s", remote)
+		path, _ := c.ConfigPath()
+		return nil, fmt.Errorf("GetRemoteClient no remote configuration found for current remote: %s path: %s", remote, path)
 	}
 	if x.Gen3 != nil {
 		configText, _ := yaml.Marshal(x.Gen3)
@@ -146,6 +147,10 @@ func getConfigPath() (string, error) {
 	return configPath, nil
 }
 
+func (c Config) ConfigPath() (string, error) {
+	return getConfigPath()
+}
+
 // updates and git adds a Git DRS config file
 // this should handle three cases:
 // 1. create a new config file if it does not exist / is empty
diff --git a/coverage/combined.html b/coverage/combined.html
index 01f6dde0..1aae2114 100644