From 7ae334d561e50f3e2a64da6eef5f351314607bd2 Mon Sep 17 00:00:00 2001 From: Brian Walsh Date: Tue, 27 Jan 2026 21:53:31 -0800 Subject: [PATCH 01/52] initial #141 --- cmd/addurl/lfss3/agent_fetch_reader.go | 161 + cmd/addurl/lfss3/inspect.go | 400 + cmd/addurl/lfss3/inspect_test.go | 263 + cmd/addurl/lfss3/is_lfs_tracked_test.go | 64 + cmd/addurl/main.go | 324 +- cmd/addurl/main_test.go | 28 - cmd/root.go | 2 +- coverage/combined.html | 1879 ++-- coverage/combined.out | 742 +- coverage/integration/coverage.out | 243 +- coverage/unit/coverage.out | 10297 +++++++++++++--------- 11 files changed, 9005 insertions(+), 5398 deletions(-) create mode 100644 cmd/addurl/lfss3/agent_fetch_reader.go create mode 100644 cmd/addurl/lfss3/inspect.go create mode 100644 cmd/addurl/lfss3/inspect_test.go create mode 100644 cmd/addurl/lfss3/is_lfs_tracked_test.go delete mode 100644 cmd/addurl/main_test.go diff --git a/cmd/addurl/lfss3/agent_fetch_reader.go b/cmd/addurl/lfss3/agent_fetch_reader.go new file mode 100644 index 00000000..69fc68c7 --- /dev/null +++ b/cmd/addurl/lfss3/agent_fetch_reader.go @@ -0,0 +1,161 @@ +package lfss3 + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "sync/atomic" + "time" +) + +// progressReader wraps an io.ReadCloser and periodically writes progress to stderr. +type progressReader struct { + rc io.ReadCloser + label string + start time.Time + total int64 // accessed atomically + quit chan struct{} + done chan struct{} + ticker time.Duration +} + +func newProgressReader(rc io.ReadCloser, label string) io.ReadCloser { + p := &progressReader{ + rc: rc, + label: label, + start: time.Now(), + quit: make(chan struct{}), + done: make(chan struct{}), + ticker: 500 * time.Millisecond, + } + + go func() { + t := time.NewTicker(p.ticker) + defer t.Stop() + var last int64 + for { + select { + case <-t.C: + total := atomic.LoadInt64(&p.total) + elapsed := time.Since(p.start).Seconds() + var rate float64 + if elapsed > 0 { + rate = float64(total) / elapsed + } + // \r to overwrite the same line like git pull; no newline until done + fmt.Fprintf(os.Stderr, "\r%s: %d bytes (%.1f KiB/s)", p.label, total, rate/1024) + last = total + case <-p.quit: + // final line (replace same line, then newline) + total := atomic.LoadInt64(&p.total) + if total == last { + fmt.Fprintf(os.Stderr, "\r%s: %d bytes\n", p.label, total) + } else { + fmt.Fprintf(os.Stderr, "\r%s: %d bytes\n", p.label, total) + } + close(p.done) + return + } + } + }() + + return p +} + +func (p *progressReader) Read(b []byte) (int, error) { + n, err := p.rc.Read(b) + if n > 0 { + atomic.AddInt64(&p.total, int64(n)) + } + return n, err +} + +func (p *progressReader) Close() error { + // Close underlying reader first, then stop progress goroutine and wait for completion. + err := p.rc.Close() + close(p.quit) + <-p.done + return err +} + +// AgentFetchReader fetches the object described by `input` and returns an io.ReadCloser. +// It accepts `s3://bucket/key` URLs and converts them to HTTPS URLs. If `input.AWSEndpoint` +// is set it will use that endpoint in path-style (endpoint/bucket/key); otherwise it +// uses the default virtual-hosted AWS form: https://{bucket}.s3.amazonaws.com/{key}. +func AgentFetchReader(ctx context.Context, input InspectInput) (io.ReadCloser, error) { + if ctx == nil { + ctx = context.Background() + } + + raw := strings.TrimSpace(input.S3URL) + if raw == "" { + return nil, fmt.Errorf("AgentFetchReader: InspectInput.S3URL is empty") + } + + u, err := url.Parse(raw) + if err != nil { + return nil, fmt.Errorf("AgentFetchReader: parse url %q: %w", raw, err) + } + + var s3url string + switch u.Scheme { + case "s3": + bucket := u.Host + key := strings.TrimPrefix(u.Path, "/") + if bucket == "" || key == "" { + return nil, fmt.Errorf("AgentFetchReader: invalid s3 URL %q", raw) + } + if ep := strings.TrimSpace(input.AWSEndpoint); ep != "" { + // ensure endpoint has a scheme + if !strings.HasPrefix(ep, "http://") && !strings.HasPrefix(ep, "https://") { + ep = "https://" + ep + } + s3url = strings.TrimRight(ep, "/") + "/" + bucket + "/" + key + } else { + s3url = fmt.Sprintf("https://%s.s3.amazonaws.com/%s", bucket, key) + } + case "", "http", "https": + // allow bare host/path (no scheme) by assuming https, otherwise use as-is + if u.Scheme == "" { + s3url = "https://" + raw + } else { + s3url = raw + } + default: + return nil, fmt.Errorf("AgentFetchReader: unsupported URL scheme %q", u.Scheme) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, s3url, nil) + if err != nil { + return nil, fmt.Errorf("AgentFetchReader: create request: %w", err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("AgentFetchReader: http get %s: %w", s3url, err) + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + if _, copyErr := io.Copy(io.Discard, resp.Body); copyErr != nil { + _ = resp.Body.Close() + return nil, fmt.Errorf("AgentFetchReader: unexpected status %d fetching %s; failed to drain body: %w", resp.StatusCode, s3url, copyErr) + } + if closeErr := resp.Body.Close(); closeErr != nil { + return nil, fmt.Errorf("AgentFetchReader: unexpected status %d fetching %s; failed to close body: %w", resp.StatusCode, s3url, closeErr) + } + return nil, fmt.Errorf("AgentFetchReader: unexpected status %d fetching %s", resp.StatusCode, s3url) + } + + if resp.Body == nil { + _ = resp.Body.Close() + return nil, fmt.Errorf("AgentFetchReader: response body is nil for %s", s3url) + } + + // wrap response body with progress reporting that writes to stderr + label := fmt.Sprintf("fetch %s", s3url) + return newProgressReader(resp.Body, label), nil +} diff --git a/cmd/addurl/lfss3/inspect.go b/cmd/addurl/lfss3/inspect.go new file mode 100644 index 00000000..fc69bc3d --- /dev/null +++ b/cmd/addurl/lfss3/inspect.go @@ -0,0 +1,400 @@ +// Package lfss3 provides a small helper for Git-LFS + S3 object introspection. +// +// It: +// 1. determines the effective Git LFS storage root (.git/lfs vs git config lfs.storage) +// 2. derives a working-tree filename from the S3 object key (basename of key) +// 3. performs an S3 HEAD Object to retrieve size and user metadata (sha256 if present) +package lfss3 + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "net/http" + "net/url" + "os/exec" + "path" + "path/filepath" + "regexp" + "runtime" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +// InspectInput is the drop-in input you requested. +type InspectInput struct { + S3URL string + AWSAccessKey string + AWSSecretKey string + AWSRegion string + AWSEndpoint string // optional: custom endpoint (Ceph/MinIO/etc.) + SHA256 string // optional expected hex (64 chars). Can be "sha256:" or "" + WorktreeName string // optional override of derived worktree name +} + +// InspectResult is what we return. +type InspectResult struct { + // Git/LFS paths + GitCommonDir string // result of: git rev-parse --git-common-dir + LFSRoot string // either lfs.storage (resolved) or /lfs + + // Object identity + Bucket string + Key string + WorktreeName string // basename of Key (filename) + + // HEAD-derived info + SizeBytes int64 + MetaSHA256 string // from user-defined object metadata (if present) + ETag string + LastModTime time.Time +} + +// InspectS3ForLFS does all 3 requested tasks. +func InspectS3ForLFS(ctx context.Context, in InspectInput) (*InspectResult, error) { + if strings.TrimSpace(in.S3URL) == "" { + return nil, errors.New("S3URL is required") + } + if strings.TrimSpace(in.AWSRegion) == "" { + return nil, errors.New("AWSRegion is required") + } + if in.AWSAccessKey == "" || in.AWSSecretKey == "" { + return nil, errors.New("AWSAccessKey and AWSSecretKey are required") + } + + // 1) Determine Git LFS storage root. + gitCommonDir, err := gitRevParseGitCommonDir(ctx) + if err != nil { + return nil, err + } + lfsRoot, err := resolveLFSRoot(ctx, gitCommonDir) + if err != nil { + return nil, err + } + if lfsRoot == "" { + lfsRoot = filepath.Join(gitCommonDir, "lfs") + } + + // 2) Parse S3 URL + derive working tree filename. + bucket, key, err := parseS3URL(in.S3URL) + if err != nil { + return nil, err + } + worktreeName := path.Base(key) + if worktreeName == "." || worktreeName == "/" || worktreeName == "" { + return nil, fmt.Errorf("could not derive worktree name from key %q", key) + } + + // 3) HEAD on S3 to determine size and meta.SHA256. + s3Client, err := newS3Client(ctx, in) + if err != nil { + return nil, err + } + head, err := s3Client.HeadObject(ctx, &s3.HeadObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + }) + if err != nil { + return nil, fmt.Errorf("s3 HeadObject failed (bucket=%q key=%q): %w", bucket, key, err) + } + + metaSHA := extractSHA256FromMetadata(head.Metadata) + + // Optional: validate provided SHA256 against metadata if both exist. + expected := normalizeSHA256(in.SHA256) + if expected != "" && metaSHA != "" && !strings.EqualFold(expected, metaSHA) { + return nil, fmt.Errorf("sha256 mismatch: expected=%s head.meta=%s", expected, metaSHA) + } + + var lm time.Time + if head.LastModified != nil { + lm = *head.LastModified + } + + out := &InspectResult{ + GitCommonDir: gitCommonDir, + LFSRoot: lfsRoot, + Bucket: bucket, + Key: key, + WorktreeName: worktreeName, + SizeBytes: *head.ContentLength, + MetaSHA256: metaSHA, + ETag: strings.Trim(*head.ETag, `"`), + LastModTime: lm, + } + return out, nil +} + +// +// --- Git helpers --- +// + +func gitRevParseGitCommonDir(ctx context.Context) (string, error) { + out, err := runGit(ctx, "rev-parse", "--git-common-dir") + if err != nil { + return "", fmt.Errorf("git rev-parse --git-common-dir failed: %w", err) + } + dir := strings.TrimSpace(out) + if dir == "" { + return "", errors.New("git rev-parse returned empty --git-common-dir") + } + // If relative, resolve it against current working directory. + if !filepath.IsAbs(dir) { + abs, err := filepath.Abs(dir) + if err == nil { + dir = abs + } + } + return dir, nil +} + +// resolveLFSRoot implements: +// - if `git config --get lfs.storage` is set: use it +// - if relative: resolve relative to GitCommonDir (this is how git-lfs treats it in practice) +// +// - else: /lfs +func resolveLFSRoot(ctx context.Context, gitCommonDir string) (string, error) { + // NOTE: git config --get returns exit status 1 if key not found. + out, err := runGitAllowMissing(ctx, "config", "--get", "lfs.storage") + if err != nil { + return "", fmt.Errorf("git config --get lfs.storage failed: %w", err) + } + val := strings.TrimSpace(out) + + if val == "" { + return filepath.Clean(filepath.Join(gitCommonDir, "lfs")), nil + } + + // Expand ~ if present (nice-to-have). + if strings.HasPrefix(val, "~"+string(filepath.Separator)) || val == "~" { + home, herr := userHomeDir() + if herr == nil && home != "" { + val = filepath.Join(home, strings.TrimPrefix(val, "~")) + } + } + + if !filepath.IsAbs(val) { + val = filepath.Join(gitCommonDir, val) + } + return filepath.Clean(val), nil +} + +func runGit(ctx context.Context, args ...string) (string, error) { + cmd := exec.CommandContext(ctx, "git", args...) + b, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("%v: %s", err, strings.TrimSpace(string(b))) + } + return string(b), nil +} + +// runGitAllowMissing treats "key not found" as empty output, not an error. +func runGitAllowMissing(ctx context.Context, args ...string) (string, error) { + cmd := exec.CommandContext(ctx, "git", args...) + b, err := cmd.CombinedOutput() + if err != nil { + // "git config --get missing.key" exits 1 with empty output. + s := strings.TrimSpace(string(b)) + if s == "" { + return "", nil + } + return "", fmt.Errorf("%v: %s", err, s) + } + return string(b), nil +} + +func userHomeDir() (string, error) { + // Avoid os/user on some cross-compile scenarios; keep it simple. + if runtime.GOOS == "windows" { + // Not your target, but safe fallback. + return "", errors.New("home expansion not supported on windows in this helper") + } + // macOS/Linux + out, err := exec.Command("sh", "-lc", "printf %s \"$HOME\"").CombinedOutput() + if err != nil { + return "", err + } + return strings.TrimSpace(string(out)), nil +} + +// +// --- S3 parsing + client --- +// + +var virtualHostedRE = regexp.MustCompile(`^(.+?)\.s3(?:[.-]|$)`) + +// parseS3URL parses s3://bucket/key, virtual-hosted HTTPS (bucket.s3.../key) +// and path-style HTTPS (s3.../bucket/key). Returns bucket and key. +func parseS3URL(raw string) (string, string, error) { + u, err := url.Parse(raw) + if err != nil { + return "", "", err + } + + switch u.Scheme { + case "s3": + bucket := u.Host + key := strings.TrimPrefix(u.Path, "/") + return bucket, key, nil + case "http", "https": + host := u.Hostname() + + // virtual-hosted: bucket.s3.amazonaws.com or bucket.s3-region.amazonaws.com + if m := virtualHostedRE.FindStringSubmatch(host); m != nil { + bucket := m[1] + key := strings.TrimPrefix(u.Path, "/") + return bucket, key, nil + } + + // path-style: s3.../bucket/key + path := strings.TrimPrefix(u.Path, "/") + if path == "" { + return "", "", fmt.Errorf("no bucket in URL: %s", raw) + } + parts := strings.SplitN(path, "/", 2) + bucket := parts[0] + key := "" + if len(parts) == 2 { + key = parts[1] + } + return bucket, key, nil + default: + return "", "", fmt.Errorf("unsupported scheme: %s", u.Scheme) + } +} + +func newS3Client(ctx context.Context, in InspectInput) (*s3.Client, error) { + creds := credentials.NewStaticCredentialsProvider(in.AWSAccessKey, in.AWSSecretKey, "") + + // Custom HTTP client is useful for S3-compatible endpoints. + httpClient := &http.Client{ + Timeout: 60 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + }, + }, + } + + cfg, err := config.LoadDefaultConfig(ctx, + config.WithRegion(in.AWSRegion), + config.WithCredentialsProvider(creds), + config.WithHTTPClient(httpClient), + ) + if err != nil { + return nil, fmt.Errorf("aws config init failed: %w", err) + } + + opts := []func(*s3.Options){} + if strings.TrimSpace(in.AWSEndpoint) != "" { + ep := strings.TrimRight(in.AWSEndpoint, "/") + opts = append(opts, func(o *s3.Options) { + o.UsePathStyle = true // usually required for Ceph/MinIO/custom endpoints + o.BaseEndpoint = aws.String(ep) + }) + } + + return s3.NewFromConfig(cfg, opts...), nil +} + +// +// --- SHA256 metadata extraction --- +// + +var sha256HexRe = regexp.MustCompile(`(?i)^[0-9a-f]{64}$`) + +func normalizeSHA256(s string) string { + s = strings.TrimSpace(s) + s = strings.TrimPrefix(strings.ToLower(s), "sha256:") + s = strings.TrimSpace(s) + if s == "" { + return "" + } + if !sha256HexRe.MatchString(s) { + // If caller provided something malformed, treat as empty. + // Change this to a hard error if you prefer. + return "" + } + return strings.ToLower(s) +} + +func extractSHA256FromMetadata(md map[string]string) string { + if md == nil { + return "" + } + + // AWS SDK v2 exposes user-defined metadata WITHOUT the "x-amz-meta-" prefix, + // and normalizes keys to lower-case. + candidates := []string{ + "sha256", + "checksum-sha256", + "content-sha256", + "oid-sha256", + "git-lfs-sha256", + } + + for _, k := range candidates { + if v, ok := md[k]; ok { + n := normalizeSHA256(v) + if n != "" { + return n + } + } + } + + // Sometimes people stash "sha256:" + for _, v := range md { + if n := normalizeSHA256(v); n != "" { + return n + } + } + + return "" +} + +// IsLFSTracked returns true if the given path is tracked by Git LFS +// (i.e. has `filter=lfs` via git attributes). +func IsLFSTracked(path string) (bool, error) { + if path == "" { + return false, fmt.Errorf("path is empty") + } + + // Git prefers forward slashes, even on macOS/Linux + cleanPath := filepath.ToSlash(path) + + cmd := exec.Command( + "git", + "check-attr", + "filter", + "--", + cleanPath, + ) + + var out bytes.Buffer + cmd.Stdout = &out + cmd.Stderr = &out + + if err := cmd.Run(); err != nil { + return false, fmt.Errorf("git check-attr failed: %w (%s)", err, out.String()) + } + + // Expected output: + // path: filter: lfs + // path: filter: unspecified + // + // Format is stable and documented. + fields := strings.Split(out.String(), ":") + if len(fields) < 3 { + return false, nil + } + + value := strings.TrimSpace(fields[2]) + return value == "lfs", nil +} diff --git a/cmd/addurl/lfss3/inspect_test.go b/cmd/addurl/lfss3/inspect_test.go new file mode 100644 index 00000000..1a527a37 --- /dev/null +++ b/cmd/addurl/lfss3/inspect_test.go @@ -0,0 +1,263 @@ +package lfss3 + +import ( + "context" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "testing" +) + +func TestParseS3URL_S3Scheme(t *testing.T) { + b, k, err := parseS3URL("s3://my-bucket/path/to/file.bam") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if b != "my-bucket" { + t.Fatalf("bucket mismatch: %q", b) + } + if k != "path/to/file.bam" { + t.Fatalf("key mismatch: %q", k) + } +} + +func TestParseS3URL_HTTPSPathStyle(t *testing.T) { + b, k, err := parseS3URL("https://s3.example.org/my-bucket/path/to/file.bam") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if b != "my-bucket" { + t.Fatalf("bucket mismatch: %q", b) + } + if k != "path/to/file.bam" { + t.Fatalf("key mismatch: %q", k) + } +} + +func TestParseS3URL_HTTPSVirtualHosted(t *testing.T) { + b, k, err := parseS3URL("https://my-bucket.s3.amazonaws.com/path/to/file.bam") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if b != "my-bucket" { + t.Fatalf("bucket mismatch: %q", b) + } + if k != "path/to/file.bam" { + t.Fatalf("key mismatch: %q", k) + } +} + +func TestNormalizeSHA256(t *testing.T) { + hex := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + + if got := normalizeSHA256(hex); got != hex { + t.Fatalf("expected %q, got %q", hex, got) + } + + if got := normalizeSHA256("sha256:" + strings.ToUpper(hex)); got != hex { + t.Fatalf("expected %q, got %q", hex, got) + } + + if got := normalizeSHA256("not-a-sha"); got != "" { + t.Fatalf("expected empty for invalid, got %q", got) + } +} + +func TestExtractSHA256FromMetadata_ByKey(t *testing.T) { + hex := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + + md := map[string]string{ + "sha256": hex, + } + got := extractSHA256FromMetadata(md) + if got != hex { + t.Fatalf("expected %q, got %q", hex, got) + } +} + +func TestExtractSHA256FromMetadata_ByAlternateKey(t *testing.T) { + hex := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + + md := map[string]string{ + "checksum-sha256": "sha256:" + hex, + } + got := extractSHA256FromMetadata(md) + if got != hex { + t.Fatalf("expected %q, got %q", hex, got) + } +} + +func TestExtractSHA256FromMetadata_SearchValues(t *testing.T) { + hex := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + + md := map[string]string{ + "something": "sha256:" + hex, + } + got := extractSHA256FromMetadata(md) + if got != hex { + t.Fatalf("expected %q, got %q", hex, got) + } +} + +func TestGitCommonDirAndResolveLFSRoot_Default(t *testing.T) { + if _, err := exec.LookPath("git"); err != nil { + t.Skip("git not found in PATH") + } + + ctx := context.Background() + repo := t.TempDir() + + mustRun(t, repo, "git", "init") + // ensure we're in that repo for git config calls + oldwd := mustChdir(t, repo) + t.Cleanup(func() { _ = os.Chdir(oldwd) }) + + gitCommon, err := gitRevParseGitCommonDir(ctx) + if err != nil { + t.Fatalf("gitRevParseGitCommonDir: %v", err) + } + + lfsRoot, err := resolveLFSRoot(ctx, gitCommon) + if err != nil { + t.Fatalf("resolveLFSRoot: %v", err) + } + + want := filepath.Clean(filepath.Join(gitCommon, "lfs")) + if lfsRoot != want { + t.Fatalf("expected lfsRoot %q, got %q", want, lfsRoot) + } +} + +func TestResolveLFSRoot_ConfigAbsolute(t *testing.T) { + if _, err := exec.LookPath("git"); err != nil { + t.Skip("git not found in PATH") + } + + ctx := context.Background() + repo := t.TempDir() + absStorage := filepath.Join(repo, "custom-lfs-storage") + + mustRun(t, repo, "git", "init") + + oldwd := mustChdir(t, repo) + t.Cleanup(func() { _ = os.Chdir(oldwd) }) + + // set lfs.storage + mustRun(t, repo, "git", "config", "lfs.storage", absStorage) + + gitCommon, err := gitRevParseGitCommonDir(ctx) + if err != nil { + t.Fatalf("gitRevParseGitCommonDir: %v", err) + } + + lfsRoot, err := resolveLFSRoot(ctx, gitCommon) + if err != nil { + t.Fatalf("resolveLFSRoot: %v", err) + } + + want := filepath.Clean(absStorage) + if lfsRoot != want { + t.Fatalf("expected %q, got %q", want, lfsRoot) + } +} + +func TestResolveLFSRoot_ConfigRelative(t *testing.T) { + if _, err := exec.LookPath("git"); err != nil { + t.Skip("git not found in PATH") + } + + ctx := context.Background() + repo := t.TempDir() + mustRun(t, repo, "git", "init") + + oldwd := mustChdir(t, repo) + t.Cleanup(func() { _ = os.Chdir(oldwd) }) + + // relative storage path (resolved under gitCommonDir in our helper) + mustRun(t, repo, "git", "config", "lfs.storage", "rel-lfs") + + gitCommon, err := gitRevParseGitCommonDir(ctx) + if err != nil { + t.Fatalf("gitRevParseGitCommonDir: %v", err) + } + + lfsRoot, err := resolveLFSRoot(ctx, gitCommon) + if err != nil { + t.Fatalf("resolveLFSRoot: %v", err) + } + + want := filepath.Clean(filepath.Join(gitCommon, "rel-lfs")) + if lfsRoot != want { + t.Fatalf("expected %q, got %q", want, lfsRoot) + } +} + +func TestResolveLFSRoot_ConfigTildeExpansion(t *testing.T) { + // This test relies on `sh -lc` in userHomeDir, which we don't run on Windows. + if runtime.GOOS == "windows" { + t.Skip("tilde expansion test skipped on windows") + } + if _, err := exec.LookPath("git"); err != nil { + t.Skip("git not found in PATH") + } + + ctx := context.Background() + repo := t.TempDir() + home := filepath.Join(repo, "fake-home") + if err := os.MkdirAll(home, 0o755); err != nil { + t.Fatalf("mkdir fake home: %v", err) + } + + // Force HOME so userHomeDir() resolves consistently + oldHome := os.Getenv("HOME") + _ = os.Setenv("HOME", home) + t.Cleanup(func() { _ = os.Setenv("HOME", oldHome) }) + + mustRun(t, repo, "git", "init") + + oldwd := mustChdir(t, repo) + t.Cleanup(func() { _ = os.Chdir(oldwd) }) + + mustRun(t, repo, "git", "config", "lfs.storage", "~/lfs-store") + + gitCommon, err := gitRevParseGitCommonDir(ctx) + if err != nil { + t.Fatalf("gitRevParseGitCommonDir: %v", err) + } + + lfsRoot, err := resolveLFSRoot(ctx, gitCommon) + if err != nil { + t.Fatalf("resolveLFSRoot: %v", err) + } + + want := filepath.Clean(filepath.Join(home, "lfs-store")) + if lfsRoot != want { + t.Fatalf("expected %q, got %q", want, lfsRoot) + } +} + +// --- test helpers --- + +func mustRun(t *testing.T, dir string, name string, args ...string) { + t.Helper() + cmd := exec.Command(name, args...) + cmd.Dir = dir + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("command failed: %s %v\nerr=%v\nout=%s", name, args, err, string(out)) + } +} + +func mustChdir(t *testing.T, dir string) string { + t.Helper() + old, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd: %v", err) + } + if err := os.Chdir(dir); err != nil { + t.Fatalf("Chdir(%s): %v", dir, err) + } + return old +} diff --git a/cmd/addurl/lfss3/is_lfs_tracked_test.go b/cmd/addurl/lfss3/is_lfs_tracked_test.go new file mode 100644 index 00000000..f7477b2e --- /dev/null +++ b/cmd/addurl/lfss3/is_lfs_tracked_test.go @@ -0,0 +1,64 @@ +package lfss3 + +import ( + "os" + "os/exec" + "path/filepath" + "testing" +) + +func TestIsLFSTracked(t *testing.T) { + if _, err := exec.LookPath("git"); err != nil { + t.Skip("git not found in PATH") + } + + repo := t.TempDir() + mustRun(t, repo, "git", "init") + + // Add an LFS tracking rule (we only need attributes; git-lfs binary not required) + attr := []byte("*.dat filter=lfs diff=lfs merge=lfs -text\n") + if err := os.WriteFile(filepath.Join(repo, ".gitattributes"), attr, 0o644); err != nil { + t.Fatalf("write .gitattributes: %v", err) + } + + // Create files + tracked := filepath.Join(repo, "data", "file.dat") + untracked := filepath.Join(repo, "data", "file.txt") + if err := os.MkdirAll(filepath.Dir(tracked), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(tracked, []byte("x"), 0o644); err != nil { + t.Fatalf("write tracked file: %v", err) + } + if err := os.WriteFile(untracked, []byte("y"), 0o644); err != nil { + t.Fatalf("write untracked file: %v", err) + } + + // Run from inside repo so git check-attr works + oldwd, err := os.Getwd() + if err != nil { + t.Fatalf("getwd: %v", err) + } + if err := os.Chdir(repo); err != nil { + t.Fatalf("chdir: %v", err) + } + t.Cleanup(func() { _ = os.Chdir(oldwd) }) + + // Verify tracked + got, err := IsLFSTracked("data/file.dat") + if err != nil { + t.Fatalf("IsLFSTracked tracked: %v", err) + } + if !got { + t.Fatalf("expected data/file.dat to be LFS tracked") + } + + // Verify untracked + got, err = IsLFSTracked("data/file.txt") + if err != nil { + t.Fatalf("IsLFSTracked untracked: %v", err) + } + if got { + t.Fatalf("expected data/file.txt to NOT be LFS tracked") + } +} diff --git a/cmd/addurl/main.go b/cmd/addurl/main.go index 12427369..d5abf93c 100644 --- a/cmd/addurl/main.go +++ b/cmd/addurl/main.go @@ -1,119 +1,281 @@ package addurl import ( + "context" + "crypto/sha256" "errors" "fmt" + "io" + "net/url" "os" - "os/exec" "path/filepath" + "strings" - "github.com/calypr/git-drs/config" - "github.com/calypr/git-drs/drslog" - "github.com/calypr/git-drs/s3_utils" - "github.com/calypr/git-drs/utils" "github.com/spf13/cobra" + + "github.com/calypr/git-drs/cmd/addurl/lfss3" + "github.com/calypr/git-drs/s3_utils" ) -// AddURLCmd represents the add-url command -var AddURLCmd = &cobra.Command{ - Use: "add-url ", +var Cmd = &cobra.Command{ + Use: "add-url [path]", Short: "Add a file to the Git DRS repo using an S3 URL", Args: func(cmd *cobra.Command, args []string) error { - if len(args) != 2 { - cmd.SilenceUsage = false - return fmt.Errorf("error: requires exactly 2 arguments (S3 URL and SHA256), received %d\n\nUsage: %s\n\nSee 'git drs add-url --help' for more details", len(args), cmd.UseLine()) + if len(args) < 1 || len(args) > 2 { + return errors.New("usage: add-url [path]") } return nil }, - RunE: func(cmd *cobra.Command, args []string) error { - myLogger := drslog.GetLogger() + RunE: runAddURL, +} - // set git config lfs.allowincompletepush = true - configCmd := exec.Command("git", "config", "lfs.allowincompletepush", "true") - if err := configCmd.Run(); err != nil { - return fmt.Errorf("unable to configure git to push pointers: %v. Please change the .git/config file to include an [lfs] section with allowincompletepush = true", err) - } +func init() { + Cmd.Flags().String( + s3_utils.AWS_KEY_FLAG_NAME, + os.Getenv(s3_utils.AWS_KEY_ENV_VAR), + "AWS access key", + ) - // Parse arguments - s3URL := args[0] - sha256 := args[1] - awsAccessKey, _ := cmd.Flags().GetString(s3_utils.AWS_KEY_FLAG_NAME) - awsSecretKey, _ := cmd.Flags().GetString(s3_utils.AWS_SECRET_FLAG_NAME) - awsRegion, _ := cmd.Flags().GetString(s3_utils.AWS_REGION_FLAG_NAME) - awsEndpoint, _ := cmd.Flags().GetString(s3_utils.AWS_ENDPOINT_URL_FLAG_NAME) - remote, _ := cmd.Flags().GetString("remote") - - // if providing credentials, access key and secret must both be provided - if (awsAccessKey == "" && awsSecretKey != "") || (awsAccessKey != "" && awsSecretKey == "") { - return errors.New("incomplete credentials provided as environment variables. Please run `export " + s3_utils.AWS_KEY_ENV_VAR + "=` and `export " + s3_utils.AWS_SECRET_ENV_VAR + "=` to configure both") - } + Cmd.Flags().String( + s3_utils.AWS_SECRET_FLAG_NAME, + os.Getenv(s3_utils.AWS_SECRET_ENV_VAR), + "AWS secret key", + ) - // if none provided, use default AWS configuration on file - if awsAccessKey == "" && awsSecretKey == "" { - myLogger.Debug("No AWS credentials provided. Using default AWS configuration from file.") - } + Cmd.Flags().String( + s3_utils.AWS_REGION_FLAG_NAME, + os.Getenv(s3_utils.AWS_REGION_ENV_VAR), + "AWS S3 region", + ) - cfg, err := config.LoadConfig() - if err != nil { - return fmt.Errorf("error loading config: %v", err) - } + Cmd.Flags().String( + s3_utils.AWS_ENDPOINT_URL_FLAG_NAME, + os.Getenv(s3_utils.AWS_ENDPOINT_URL_ENV_VAR), + "AWS S3 endpoint (optional, for Ceph/MinIO)", + ) - remoteName, err := cfg.GetRemoteOrDefault(remote) - if err != nil { - return fmt.Errorf("error getting default remote: %v", err) - } + // New flag: optional expected SHA256 + Cmd.Flags().String( + "sha256", + "", + "Expected SHA256 checksum (optional)", + ) +} - drsClient, err := cfg.GetRemoteClient(remoteName, myLogger) - if err != nil { - return fmt.Errorf("error getting current remote client: %v", err) - } +func runAddURL(cmd *cobra.Command, args []string) (err error) { + ctx := cmd.Context() + if ctx == nil { + ctx = context.Background() + } + + s3URL := args[0] - // Call client.AddURL to handle Gen3 interactions - meta, err := drsClient.AddURL(s3URL, sha256, awsAccessKey, awsSecretKey, awsRegion, awsEndpoint) - if err != nil { - return err + // Determine path: use provided optional arg, otherwise derive from URL path + var pathArg string + if len(args) == 2 { + pathArg = args[1] + } else { + u, perr := url.Parse(s3URL) + if perr != nil { + return perr } + pathArg = strings.TrimPrefix(u.Path, "/") + } + + sha256Param, ferr := cmd.Flags().GetString("sha256") + if ferr != nil { + return fmt.Errorf("read flag sha256: %w", ferr) + } + + awsKey, ferr := cmd.Flags().GetString(s3_utils.AWS_KEY_FLAG_NAME) + if ferr != nil { + return fmt.Errorf("read flag %s: %w", s3_utils.AWS_KEY_FLAG_NAME, ferr) + } + awsSecret, ferr := cmd.Flags().GetString(s3_utils.AWS_SECRET_FLAG_NAME) + if ferr != nil { + return fmt.Errorf("read flag %s: %w", s3_utils.AWS_SECRET_FLAG_NAME, ferr) + } + awsRegion, ferr := cmd.Flags().GetString(s3_utils.AWS_REGION_FLAG_NAME) + if ferr != nil { + return fmt.Errorf("read flag %s: %w", s3_utils.AWS_REGION_FLAG_NAME, ferr) + } + awsEndpoint, ferr := cmd.Flags().GetString(s3_utils.AWS_ENDPOINT_URL_FLAG_NAME) + if ferr != nil { + return fmt.Errorf("read flag %s: %w", s3_utils.AWS_ENDPOINT_URL_FLAG_NAME, ferr) + } + + if awsKey == "" || awsSecret == "" { + return errors.New("AWS credentials must be provided via flags or environment variables") + } + if awsRegion == "" { + return errors.New("AWS region must be provided via flag or environment variable") + } + + s3Input := lfss3.InspectInput{ + S3URL: s3URL, + AWSAccessKey: awsKey, + AWSSecretKey: awsSecret, + AWSRegion: awsRegion, + AWSEndpoint: awsEndpoint, + SHA256: sha256Param, + WorktreeName: pathArg, + } + info, err := lfss3.InspectS3ForLFS(ctx, s3Input) + if err != nil { + return err + } + + isLFSTracked, err := lfss3.IsLFSTracked(pathArg) + if err != nil { + return fmt.Errorf("check LFS tracking for %s: %w", pathArg, err) + } + + if _, err := fmt.Fprintf(cmd.OutOrStdout(), ` +Resolved Git LFS info +--------------------- +Git common dir : %s +LFS storage : %s + +S3 object +--------- +Bucket : %s +Key : %s +Worktree name : %s +Size (bytes) : %d +SHA256 (meta) : %s +ETag : %s +Last modified : %s - // Generate and add pointer file - _, relFilePath, err := utils.ParseS3URL(s3URL) - if err != nil { - return fmt.Errorf("failed to parse S3 URL: %w", err) +Worktree +------------- +path : %s +tracked by LFS : %v + +`, + info.GitCommonDir, + info.LFSRoot, + info.Bucket, + info.Key, + info.WorktreeName, + info.SizeBytes, + info.MetaSHA256, + info.ETag, + info.LastModTime.Format("2006-01-02T15:04:05Z07:00"), + pathArg, + isLFSTracked, + ); err != nil { + return fmt.Errorf("print resolved info: %w", err) + } + + // 2) object destination + tmpDir := filepath.Join(info.LFSRoot, "tmp-objects", info.ETag[0:2], info.ETag[2:4]) + tmpObj := filepath.Join(tmpDir, info.ETag) + + // 3) fetch bytes -> tmp, compute sha+count + + // replace the pseudocode with this real Go snippet (to be placed inside runAddURL) + if err = os.MkdirAll(tmpDir, 0755); err != nil { + return fmt.Errorf("mkdir %s: %w", tmpDir, err) + } + + f, err := os.Create(tmpObj) + if err != nil { + return fmt.Errorf("create %s: %w", tmpObj, err) + } + // ensure any leftover file is closed and error propagated via named return + defer func() { + if f != nil { + if cerr := f.Close(); cerr != nil && err == nil { + err = fmt.Errorf("close tmp file: %w", cerr) + } } - if err := generatePointerFile(relFilePath, sha256, meta.Size); err != nil { - return fmt.Errorf("failed to generate pointer file: %w", err) + }() + + h := sha256.New() + + var reader io.ReadCloser + reader, err = lfss3.AgentFetchReader(ctx, s3Input) + if err != nil { + return fmt.Errorf("fetch reader: %w", err) + } + // ensure close on any early return; propagate close error via named return + defer func() { + if reader != nil { + if cerr := reader.Close(); cerr != nil && err == nil { + err = fmt.Errorf("close reader: %w", cerr) + } } - myLogger.Debug("S3 URL successfully added to Git DRS repo.") - return nil - }, -} + }() -func init() { - AddURLCmd.Flags().String(s3_utils.AWS_KEY_FLAG_NAME, os.Getenv(s3_utils.AWS_KEY_ENV_VAR), "AWS access key") - AddURLCmd.Flags().String(s3_utils.AWS_SECRET_FLAG_NAME, os.Getenv(s3_utils.AWS_SECRET_ENV_VAR), "AWS secret key") - AddURLCmd.Flags().String(s3_utils.AWS_REGION_FLAG_NAME, os.Getenv(s3_utils.AWS_REGION_ENV_VAR), "AWS S3 region") - AddURLCmd.Flags().String(s3_utils.AWS_ENDPOINT_URL_FLAG_NAME, os.Getenv(s3_utils.AWS_ENDPOINT_URL_ENV_VAR), "AWS S3 endpoint") - AddURLCmd.Flags().String("remote", "", "target remote DRS server (default: default_remote)") -} + n, err := io.Copy(io.MultiWriter(f, h), reader) + if err != nil { + return fmt.Errorf("copy bytes to %s: %w", tmpObj, err) + } + + // explicitly close reader and handle error + if cerr := reader.Close(); cerr != nil { + return fmt.Errorf("close reader: %w", cerr) + } + reader = nil + + // ensure data is flushed to disk + if err = f.Sync(); err != nil { + return fmt.Errorf("sync %s: %w", tmpObj, err) + } + + // explicitly close tmp file before rename + if cerr := f.Close(); cerr != nil { + return fmt.Errorf("close %s: %w", tmpObj, cerr) + } + f = nil -func generatePointerFile(filePath string, sha256 string, fileSize int64) error { - // Define the pointer file content - pointerContent := fmt.Sprintf("version https://git-lfs.github.com/spec/v1\noid sha256:%s\nsize %d\n", sha256, fileSize) + // use n (bytes written) to avoid unused var warnings + _ = n - // Ensure the directory exists - if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil { - return fmt.Errorf("failed to create directory for pointer file: %w", err) + // compute hex SHA256 of the fetched content + computedSHA := fmt.Sprintf("%x", h.Sum(nil)) + //optional: compare with provided `sha256` flag if desired + if sha256Param != "" && sha256Param != computedSHA { + return fmt.Errorf("sha256Param mismatch: expected %s got %s", sha256Param, computedSHA) } - // Write the pointer file - if err := os.WriteFile(filePath, []byte(pointerContent), 0644); err != nil { - return fmt.Errorf("failed to write pointer file: %w", err) + oid := computedSHA + dstDir := filepath.Join(info.LFSRoot, "objects", oid[0:2], oid[2:4]) + dstObj := filepath.Join(dstDir, oid) + + if err = os.MkdirAll(dstDir, 0755); err != nil { + return fmt.Errorf("mkdir %s: %w", dstDir, err) } - // Add the pointer file to Git - cmd := exec.Command("git", "add", filePath) - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to add pointer file to Git: %w", err) + if err = os.Rename(tmpObj, dstObj); err != nil { + return fmt.Errorf("rename %s to %s: %w", tmpObj, dstObj, err) } + // 5) write pointer file in working tree + pointer := fmt.Sprintf( + "version https://git-lfs.github.com/spec/v1\noid sha256:%s\nsize %d\n", + oid, info.SizeBytes, + ) + // write pointer file to working tree pathArg + if pathArg == "" { + return fmt.Errorf("empty worktree path") + } + safePath := filepath.Clean(pathArg) + dir := filepath.Dir(safePath) + if dir != "." && dir != "/" { + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("mkdir %s: %w", dir, err) + } + } + if err = os.WriteFile(safePath, []byte(pointer), 0644); err != nil { + return fmt.Errorf("write %s: %w", safePath, err) + } + + if _, err := fmt.Fprintf(os.Stderr, "Added data file at %s\n", dstObj); err != nil { + return fmt.Errorf("stderr write: %w", err) + } + if _, err := fmt.Fprintf(os.Stderr, "Added Git LFS pointer file at %s\n", safePath); err != nil { + return fmt.Errorf("stderr write: %w", err) + } return nil } diff --git a/cmd/addurl/main_test.go b/cmd/addurl/main_test.go deleted file mode 100644 index 759c01dd..00000000 --- a/cmd/addurl/main_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package addurl - -import ( - "os" - "path/filepath" - "testing" - - "github.com/calypr/git-drs/internal/testutils" -) - -func TestGeneratePointerFile(t *testing.T) { - testutils.SetupTestGitRepo(t) - - path := filepath.Join("data", "file.txt") - err := generatePointerFile(path, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", 10) - if err != nil { - t.Fatalf("generatePointerFile error: %v", err) - } - - content, err := os.ReadFile(path) - if err != nil { - t.Fatalf("read pointer file: %v", err) - } - - if len(content) == 0 { - t.Fatalf("expected pointer file content") - } -} diff --git a/cmd/root.go b/cmd/root.go index 5bcc71b2..0b4e5184 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -48,7 +48,7 @@ func init() { RootCmd.AddCommand(transfer.Cmd) RootCmd.AddCommand(transferref.Cmd) RootCmd.AddCommand(version.Cmd) - RootCmd.AddCommand(addurl.AddURLCmd) + RootCmd.AddCommand(addurl.Cmd) RootCmd.AddCommand(remote.Cmd) RootCmd.AddCommand(fetch.Cmd) RootCmd.AddCommand(push.Cmd) diff --git a/coverage/combined.html b/coverage/combined.html index 01f6dde0..24b4bde8 100644 --- a/coverage/combined.html +++ b/coverage/combined.html @@ -57,17 +57,17 @@ - + - + - + - + @@ -75,95 +75,99 @@ - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + + + + + @@ -326,10 +330,10 @@ return "" } -func (s AnvilRemote) GetClient(params map[string]string, logger *slog.Logger) (client.DRSClient, error) { +func (s AnvilRemote) GetClient(params map[string]string, logger *slog.Logger) (client.DRSClient, error) { return nil, fmt.Errorf(("AnVIL Client needs to be implemented")) // return NewAnvilClient(s, logger) -} +} @@ -988,21 +992,21 @@ //////////////////// // load repo-level config and return a new IndexDClient -func NewIndexDClient(profileConfig conf.Credential, remote Gen3Remote, logger *slog.Logger) (client.DRSClient, error) { +func NewIndexDClient(profileConfig conf.Credential, remote Gen3Remote, logger *slog.Logger) (client.DRSClient, error) { baseUrl, err := url.Parse(profileConfig.APIEndpoint) // get the gen3Project and gen3Bucket from the config - projectId := remote.GetProjectId() - if projectId == "" { - return nil, fmt.Errorf("no gen3 project specified. Run 'git drs init', use the '--help' flag for more info") + projectId := remote.GetProjectId() + if projectId == "" { + return nil, fmt.Errorf("no gen3 project specified. Run 'git drs init', use the '--help' flag for more info") } - bucketName := remote.GetBucketName() - if bucketName == "" { - logger.Debug("WARNING: no gen3 bucket specified. To add a bucket, run 'git remote add gen3', use the '--help' flag for more info") + bucketName := remote.GetBucketName() + if bucketName == "" { + logger.Debug("WARNING: no gen3 bucket specified. To add a bucket, run 'git remote add gen3', use the '--help' flag for more info") } - transport := &http.Transport{ + transport := &http.Transport{ MaxIdleConns: 100, // Default pool size (across all hosts) MaxIdleConnsPerHost: 100, // Important: Pool size per *single host* (your Indexd server) IdleConnTimeout: 90 * time.Second, @@ -1015,7 +1019,7 @@ retryClient.HTTPClient = httpClient // Custom CheckRetry: do not retry when response body contains "already exists" - retryClient.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) { + retryClient.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) { if resp != nil && resp.StatusCode < 500 && resp.StatusCode >= 400 { // do not retry on 4xx // 400 => "The request could not be understood by the @@ -1037,7 +1041,7 @@ return retryablehttp.DefaultRetryPolicy(ctx, resp, err) } - retryClient.Logger = drslog.AsStdLogger(logger) + retryClient.Logger = drslog.AsStdLogger(logger) // TODO - make these configurable? retryClient.RetryMax = 5 retryClient.RetryWaitMin = 5 * time.Second @@ -1048,7 +1052,7 @@ return nil, err } - multiPartThresholdInt, err := getLfsCustomTransferInt("lfs.customtransfer.drs.multipart-threshold", 500) + multiPartThresholdInt, err := getLfsCustomTransferInt("lfs.customtransfer.drs.multipart-threshold", 500) var multiPartThreshold int64 = multiPartThresholdInt * common.MB // default 100 MB return &IndexDClient{ @@ -1064,11 +1068,11 @@ }, nil } -func (cl *IndexDClient) GetProjectId() string { +func (cl *IndexDClient) GetProjectId() string { return cl.ProjectId } -func getLfsCustomTransferBool(key string, defaultValue bool) (bool, error) { +func getLfsCustomTransferBool(key string, defaultValue bool) (bool, error) { defaultText := strconv.FormatBool(defaultValue) // TODO cache or get all the configs at once? cmd := exec.Command("git", "config", "--get", "--default", defaultText, key) @@ -1077,16 +1081,16 @@ return defaultValue, fmt.Errorf("error reading git config %s: %v", key, err) } - value := strings.TrimSpace(string(output)) + value := strings.TrimSpace(string(output)) parsed, err := strconv.ParseBool(value) if err != nil { return defaultValue, fmt.Errorf("invalid boolean value for %s: >%q<", key, value) } - return parsed, nil + return parsed, nil } -func getLfsCustomTransferInt(key string, defaultValue int64) (int64, error) { +func getLfsCustomTransferInt(key string, defaultValue int64) (int64, error) { defaultText := strconv.FormatInt(defaultValue, 10) // TODO cache or get all the configs at once? cmd := exec.Command("git", "config", "--get", "--default", defaultText, key) @@ -1095,27 +1099,27 @@ return defaultValue, fmt.Errorf("error reading git config %s: %v", key, err) } - value := strings.TrimSpace(string(output)) + value := strings.TrimSpace(string(output)) parsed, err := strconv.ParseInt(value, 10, 64) if err != nil { return defaultValue, fmt.Errorf("invalid int value for %s: >%q<", key, value) } - if parsed < 1 || parsed > 500 { + if parsed < 1 || parsed > 500 { return defaultValue, fmt.Errorf("invalid int value for %s: %d. Must be between 1 and 500", key, parsed) } - return parsed, nil + return parsed, nil } // GetProfile extracts the profile from the auth handler if available // This is only needed for external APIs like g3cmd that require it -func (cl *IndexDClient) GetProfile() (string, error) { - if rh, ok := cl.AuthHandler.(*RealAuthHandler); ok { +func (cl *IndexDClient) GetProfile() (string, error) { + if rh, ok := cl.AuthHandler.(*RealAuthHandler); ok { return rh.Cred.Profile, nil } - return "", fmt.Errorf("AuthHandler is not RealAuthHandler, cannot extract profile") + return "", fmt.Errorf("AuthHandler is not RealAuthHandler, cannot extract profile") } func (cl *IndexDClient) DeleteRecordsByProject(projectId string) error { @@ -1456,7 +1460,7 @@ return true, nil } -func (cl *IndexDClient) GetObject(id string) (*drs.DRSObject, error) { +func (cl *IndexDClient) GetObject(id string) (*drs.DRSObject, error) { a := *cl.Base a.Path = filepath.Join(a.Path, "ga4gh/drs/v1/objects", id) @@ -1466,30 +1470,30 @@ return nil, err } - err = cl.AuthHandler.AddAuthHeader(req.Request) + err = cl.AuthHandler.AddAuthHeader(req.Request) if err != nil { return nil, fmt.Errorf("error adding Gen3 auth header: %v", err) } - response, err := cl.HttpClient.Do(req) + response, err := cl.HttpClient.Do(req) if err != nil { return nil, err } - defer response.Body.Close() + defer response.Body.Close() if response.Status == "404" { return nil, fmt.Errorf("%s not found", id) } - in := drs.OutputObject{} + in := drs.OutputObject{} if err := cl.SConfig.NewDecoder(response.Body).Decode(&in); err != nil { return nil, err } - return drs.ConvertOutputObjectToDRSObject(&in), nil + return drs.ConvertOutputObjectToDRSObject(&in), nil } -func (cl *IndexDClient) ListObjectsByProject(projectId string) (chan drs.DRSObjectResult, error) { +func (cl *IndexDClient) ListObjectsByProject(projectId string) (chan drs.DRSObjectResult, error) { const PAGESIZE = 50 pageNum := 0 @@ -1499,26 +1503,26 @@ return nil, err } - a := *cl.Base + a := *cl.Base a.Path = filepath.Join(a.Path, "index/index") out := make(chan drs.DRSObjectResult, PAGESIZE) - go func() { + go func() { defer close(out) // This will hold all errors encountered during the loop var resultErrors *multierror.Error active := true - for active { + for active { req, err := retryablehttp.NewRequest("GET", a.String(), nil) if err != nil { resultErrors = multierror.Append(resultErrors, fmt.Errorf("request creation: %w", err)) break } - q := req.URL.Query() + q := req.URL.Query() q.Add("authz", resourcePath) q.Add("limit", fmt.Sprintf("%d", PAGESIZE)) q.Add("page", fmt.Sprintf("%d", pageNum)) @@ -1529,14 +1533,14 @@ break } - response, err := cl.HttpClient.Do(req) + response, err := cl.HttpClient.Do(req) if err != nil { resultErrors = multierror.Append(resultErrors, fmt.Errorf("http call: %w", err)) break } // Read body and close immediately - body, err := io.ReadAll(response.Body) + body, err := io.ReadAll(response.Body) response.Body.Close() if err != nil { @@ -1544,45 +1548,45 @@ break } - if response.StatusCode != http.StatusOK { + if response.StatusCode != http.StatusOK { resultErrors = multierror.Append(resultErrors, fmt.Errorf("api error %d: %s", response.StatusCode, string(body))) break } - page := &ListRecords{} + page := &ListRecords{} if err := cl.SConfig.Unmarshal(body, &page); err != nil { resultErrors = multierror.Append(resultErrors, fmt.Errorf("unmarshal: %w", err)) break } - if len(page.Records) == 0 { + if len(page.Records) == 0 { active = false } - for _, elem := range page.Records { + for _, elem := range page.Records { drsObj, err := elem.ToIndexdRecord().ToDrsObject() if err != nil { // Append and keep going, or break if this is fatal resultErrors = multierror.Append(resultErrors, err) continue } - out <- drs.DRSObjectResult{Object: drsObj} + out <- drs.DRSObjectResult{Object: drsObj} } - pageNum++ + pageNum++ } // If we accumulated any errors, send the final concatenated result - if resultErrors != nil { + if resultErrors != nil { out <- drs.DRSObjectResult{Error: resultErrors.ErrorOrNil()} } }() - return out, nil + return out, nil } // given indexd record, constructs a new indexd record // implements /index/index POST -func (cl *IndexDClient) RegisterIndexdRecord(indexdObj *IndexdRecord) (*drs.DRSObject, error) { +func (cl *IndexDClient) RegisterIndexdRecord(indexdObj *IndexdRecord) (*drs.DRSObject, error) { indexdObjForm := IndexdRecordForm{ IndexdRecord: *indexdObj, Form: "object", @@ -1593,7 +1597,7 @@ return nil, err } - cl.Logger.Debug(fmt.Sprintf("writing IndexdObj: %s", string(jsonBytes))) + cl.Logger.Debug(fmt.Sprintf("writing IndexdObj: %s", string(jsonBytes))) // register DRS object via /index POST // (setup post request to indexd) @@ -1605,7 +1609,7 @@ return nil, err } // set Content-Type header for JSON - req.Header.Set("accept", "application/json") + req.Header.Set("accept", "application/json") req.Header.Set("Content-Type", "application/json") // add auth token @@ -1614,12 +1618,12 @@ return nil, fmt.Errorf("error adding Gen3 auth header: %v", err) } - cl.Logger.Debug(fmt.Sprintf("POST request created for indexd: %s", endpt.String())) + cl.Logger.Debug(fmt.Sprintf("POST request created for indexd: %s", endpt.String())) response, err := cl.HttpClient.Do(req) if err != nil { return nil, err } - defer response.Body.Close() + defer response.Body.Close() // check and see if the response status is OK drsId := indexdObjForm.Did @@ -1627,14 +1631,14 @@ body, _ := io.ReadAll(response.Body) return nil, fmt.Errorf("failed to register DRS ID %s: %s", drsId, body) } - cl.Logger.Debug(fmt.Sprintf("POST successful: %s", response.Status)) + cl.Logger.Debug(fmt.Sprintf("POST successful: %s", response.Status)) // removed re-query return DRS object (was missing access method authorization anyway) drsObj, err := indexdRecordToDrsObject(indexdObj) if err != nil { return nil, fmt.Errorf("error converting indexd record to DRS object: %w %v", err, indexdObj) } - return drsObj, nil + return drsObj, nil } // implements /index{did}?rev={rev} DELETE @@ -1676,7 +1680,7 @@ } // implements /index/index?hash={hashType}:{hash} GET -func (cl *IndexDClient) GetObjectByHash(sum *hash.Checksum) ([]drs.DRSObject, error) { +func (cl *IndexDClient) GetObjectByHash(sum *hash.Checksum) ([]drs.DRSObject, error) { // setup get request to indexd url := fmt.Sprintf("%s/index/index?hash=%s:%s", cl.Base.String(), sum.Type, sum.Checksum) cl.Logger.Debug(fmt.Sprintf("Querying indexd at %s", url)) @@ -1685,20 +1689,20 @@ cl.Logger.Debug(fmt.Sprintf("http.NewRequest Error: %s", err)) return nil, err } - cl.Logger.Debug(fmt.Sprintf("Looking for files with hash %s:%s", sum.Type, sum.Checksum)) + cl.Logger.Debug(fmt.Sprintf("Looking for files with hash %s:%s", sum.Type, sum.Checksum)) err = cl.AuthHandler.AddAuthHeader(req.Request) if err != nil { return nil, fmt.Errorf("unable to add authentication when searching for object: %s:%s. More on the error: %v", sum.Type, sum.Checksum, err) } - req.Header.Set("accept", "application/json") + req.Header.Set("accept", "application/json") // run request and do checks resp, err := cl.HttpClient.Do(req) if err != nil { return nil, fmt.Errorf("unable to check if server has files with hash %s:%s: %v", sum.Type, sum.Checksum, err) } - defer resp.Body.Close() + defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) @@ -1706,13 +1710,13 @@ } // unmarshal response body - records := ListRecords{} + records := ListRecords{} err = cl.SConfig.NewDecoder(resp.Body).Decode(&records) if err != nil { return nil, fmt.Errorf("error unmarshaling (%s:%s): %v", sum.Type, sum.Checksum, err) } // log how many records were found - cl.Logger.Debug(fmt.Sprintf("Found %d indexd record(s) matching the hash %v", len(records.Records), records)) + cl.Logger.Debug(fmt.Sprintf("Found %d indexd record(s) matching the hash %v", len(records.Records), records)) out := make([]drs.DRSObject, 0, len(records.Records)) @@ -1721,18 +1725,18 @@ return out, nil } - resourcePath, _ := utils.ProjectToResource(cl.GetProjectId()) + resourcePath, _ := utils.ProjectToResource(cl.GetProjectId()) - for _, record := range records.Records { + for _, record := range records.Records { // skip records that do not authorize this client/project found := false - for _, a := range record.Authz { + for _, a := range record.Authz { if a == resourcePath { found = true break } } - if !found { + if !found { continue } @@ -1743,17 +1747,17 @@ out = append(out, *drsObj) } - return out, nil + return out, nil } // GetProjectSample retrieves a sample of DRS objects for a given project (limit: 1 by default) // Returns up to 'limit' records for preview purposes before destructive operations -func (cl *IndexDClient) GetProjectSample(projectId string, limit int) ([]drs.DRSObject, error) { - if limit <= 0 { +func (cl *IndexDClient) GetProjectSample(projectId string, limit int) ([]drs.DRSObject, error) { + if limit <= 0 { limit = 1 } - cl.Logger.Debug(fmt.Sprintf("Getting sample DRS objects from indexd for project %s (limit: %d)", projectId, limit)) + cl.Logger.Debug(fmt.Sprintf("Getting sample DRS objects from indexd for project %s (limit: %d)", projectId, limit)) // Reuse ListObjectsByProject and collect first 'limit' results objChan, err := cl.ListObjectsByProject(projectId) @@ -1761,30 +1765,30 @@ return nil, err } - result := make([]drs.DRSObject, 0, limit) - for objResult := range objChan { + result := make([]drs.DRSObject, 0, limit) + for objResult := range objChan { if objResult.Error != nil { return nil, objResult.Error } - result = append(result, *objResult.Object) + result = append(result, *objResult.Object) // Stop after collecting enough samples - if len(result) >= limit { + if len(result) >= limit { // Drain remaining results to avoid goroutine leak - go func() { + go func() { for range objChan { } }() - break + break } } - cl.Logger.Debug(fmt.Sprintf("Retrieved %d sample record(s)", len(result))) + cl.Logger.Debug(fmt.Sprintf("Retrieved %d sample record(s)", len(result))) return result, nil } // implements /index/index?authz={resource_path}&start={start}&limit={limit} GET -func (cl *IndexDClient) ListObjects() (chan drs.DRSObjectResult, error) { +func (cl *IndexDClient) ListObjects() (chan drs.DRSObjectResult, error) { cl.Logger.Debug("Getting DRS objects from indexd") @@ -1796,10 +1800,10 @@ LIMIT := 50 pageNum := 0 - go func() { + go func() { defer close(out) active := true - for active { + for active { // setup request req, err := retryablehttp.NewRequest("GET", a.String(), nil) if err != nil { @@ -1808,7 +1812,7 @@ return } - q := req.URL.Query() + q := req.URL.Query() q.Add("limit", fmt.Sprintf("%d", LIMIT)) q.Add("page", fmt.Sprintf("%d", pageNum)) req.URL.RawQuery = q.Encode() @@ -1821,7 +1825,7 @@ } // execute request with error checking - response, err := cl.HttpClient.Do(req) + response, err := cl.HttpClient.Do(req) if err != nil { cl.Logger.Debug(fmt.Sprintf("error: %s", err)) @@ -1829,44 +1833,44 @@ return } - defer response.Body.Close() + defer response.Body.Close() body, err := io.ReadAll(response.Body) if err != nil { cl.Logger.Debug(fmt.Sprintf("error: %s", err)) out <- drs.DRSObjectResult{Error: err} return } - if response.StatusCode != http.StatusOK { + if response.StatusCode != http.StatusOK { cl.Logger.Debug(fmt.Sprintf("%d: check that your credentials are valid \nfull message: %s", response.StatusCode, body)) out <- drs.DRSObjectResult{Error: fmt.Errorf("%d: check your credentials are valid, \nfull message: %s", response.StatusCode, body)} return } // return page of DRS objects - page := &drs.DRSPage{} + page := &drs.DRSPage{} err = cl.SConfig.Unmarshal(body, &page) if err != nil { cl.Logger.Debug(fmt.Sprintf("error: %s (%s)", err, body)) out <- drs.DRSObjectResult{Error: err} return } - for _, elem := range page.DRSObjects { + for _, elem := range page.DRSObjects { out <- drs.DRSObjectResult{Object: &elem} } - if len(page.DRSObjects) == 0 { + if len(page.DRSObjects) == 0 { active = false } - pageNum++ + pageNum++ } - cl.Logger.Debug(fmt.Sprintf("total pages retrieved: %d", pageNum)) + cl.Logger.Debug(fmt.Sprintf("total pages retrieved: %d", pageNum)) }() - return out, nil + return out, nil } // UpdateRecord updates an existing indexd record by GUID using the PUT /index/index/{guid} endpoint // Supports updating: URLs, name (file_name), description (metadata), version, and authz -func (cl *IndexDClient) UpdateRecord(updateInfo *drs.DRSObject, did string) (*drs.DRSObject, error) { +func (cl *IndexDClient) UpdateRecord(updateInfo *drs.DRSObject, did string) (*drs.DRSObject, error) { // Get current revision from existing record record, err := cl.GetIndexdRecordByDID(did) if err != nil { @@ -1874,7 +1878,7 @@ } // Build update payload starting with existing record values - updatePayload := UpdateInputInfo{ + updatePayload := UpdateInputInfo{ URLs: record.URLs, FileName: record.FileName, Version: record.Version, @@ -1885,13 +1889,13 @@ // Apply updates from updateInfo // Update URLs by appending new access methods (deduplicated) - if len(updateInfo.AccessMethods) > 0 { + if len(updateInfo.AccessMethods) > 0 { // Collect new URLs from access methods newURLs := make([]string, 0, len(updateInfo.AccessMethods)) - for _, a := range updateInfo.AccessMethods { + for _, a := range updateInfo.AccessMethods { newURLs = append(newURLs, a.AccessURL.URL) } - updatePayload.URLs = utils.AddUnique(updatePayload.URLs, newURLs) + updatePayload.URLs = utils.AddUnique(updatePayload.URLs, newURLs) // Append authz from access methods (deduplicated) authz := indexdAuthzFromDrsAccessMethods(updateInfo.AccessMethods) @@ -1899,29 +1903,29 @@ } // Update name (maps to file_name in indexd) - if updateInfo.Name != "" { + if updateInfo.Name != "" { updatePayload.FileName = updateInfo.Name } // Update version - if updateInfo.Version != "" { + if updateInfo.Version != "" { updatePayload.Version = updateInfo.Version } // Update description (stored in metadata) - if updateInfo.Description != "" { - if updatePayload.Metadata == nil { + if updateInfo.Description != "" { + if updatePayload.Metadata == nil { updatePayload.Metadata = make(map[string]any) } - updatePayload.Metadata["description"] = updateInfo.Description + updatePayload.Metadata["description"] = updateInfo.Description } - jsonBytes, err := cl.SConfig.Marshal(updatePayload) + jsonBytes, err := cl.SConfig.Marshal(updatePayload) if err != nil { return nil, fmt.Errorf("error marshaling indexd object form: %v", err) } - cl.Logger.Debug(fmt.Sprintf("Prepared updated indexd object for DID %s: %s", did, string(jsonBytes))) + cl.Logger.Debug(fmt.Sprintf("Prepared updated indexd object for DID %s: %s", did, string(jsonBytes))) // prepare URL updateURL := fmt.Sprintf("%s/index/index/%s?rev=%s", cl.Base.String(), did, record.Rev) @@ -1932,7 +1936,7 @@ } // Set required headers - req.Header.Set("accept", "application/json") + req.Header.Set("accept", "application/json") req.Header.Set("Content-Type", "application/json") err = cl.AuthHandler.AddAuthHeader(req.Request) @@ -1940,14 +1944,14 @@ return nil, fmt.Errorf("error adding Gen3 auth header: %v", err) } - cl.Logger.Debug(fmt.Sprintf("PUT request created for indexd update: %s", updateURL)) + cl.Logger.Debug(fmt.Sprintf("PUT request created for indexd update: %s", updateURL)) // Execute the request response, err := cl.HttpClient.Do(req) if err != nil { return nil, fmt.Errorf("error executing PUT request: %v", err) } - defer response.Body.Close() + defer response.Body.Close() // Check response status if response.StatusCode != http.StatusOK { @@ -1955,7 +1959,7 @@ return nil, fmt.Errorf("failed to update indexd record %s: status %d, body: %s", did, response.StatusCode, string(body)) } - cl.Logger.Debug(fmt.Sprintf("PUT request successful: %s", response.Status)) + cl.Logger.Debug(fmt.Sprintf("PUT request successful: %s", response.Status)) // Query and return the updated DRS object updatedDrsObj, err := cl.GetObject(did) @@ -1963,12 +1967,12 @@ return nil, fmt.Errorf("error retrieving updated DRS object: %v", err) } - cl.Logger.Debug(fmt.Sprintf("Successfully updated and retrieved DRS object: %s", did)) + cl.Logger.Debug(fmt.Sprintf("Successfully updated and retrieved DRS object: %s", did)) return updatedDrsObj, nil } // Helper function to get indexd record by DID (similar to existing pattern in DeleteIndexdRecord) -func (cl *IndexDClient) GetIndexdRecordByDID(did string) (*OutputInfo, error) { +func (cl *IndexDClient) GetIndexdRecordByDID(did string) (*OutputInfo, error) { url := fmt.Sprintf("%s/index/%s", cl.Base.String(), did) req, err := retryablehttp.NewRequest("GET", url, nil) @@ -1976,45 +1980,45 @@ return nil, err } - err = cl.AuthHandler.AddAuthHeader(req.Request) + err = cl.AuthHandler.AddAuthHeader(req.Request) if err != nil { return nil, fmt.Errorf("error adding Gen3 auth header: %v", err) } - req.Request.Header.Set("accept", "application/json") + req.Request.Header.Set("accept", "application/json") resp, err := cl.HttpClient.Do(req) if err != nil { return nil, err } - defer resp.Body.Close() + defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) return nil, fmt.Errorf("failed to get record: status %d, body: %s", resp.StatusCode, string(body)) } - record := &OutputInfo{} + record := &OutputInfo{} if err := cl.SConfig.NewDecoder(resp.Body).Decode(record); err != nil { return nil, fmt.Errorf("error decoding response body: %v", err) } - return record, nil + return record, nil } -func (cl *IndexDClient) BuildDrsObj(fileName string, checksum string, size int64, drsId string) (*drs.DRSObject, error) { +func (cl *IndexDClient) BuildDrsObj(fileName string, checksum string, size int64, drsId string) (*drs.DRSObject, error) { bucket := cl.BucketName - if bucket == "" { + if bucket == "" { return nil, fmt.Errorf("error: bucket name is empty in config file") } //TODO: support other storage backends - fileURL := fmt.Sprintf("s3://%s", filepath.Join(bucket, drsId, checksum)) + fileURL := fmt.Sprintf("s3://%s", filepath.Join(bucket, drsId, checksum)) authzStr, err := utils.ProjectToResource(cl.GetProjectId()) if err != nil { return nil, err } - authorizations := drs.Authorizations{ + authorizations := drs.Authorizations{ Value: authzStr, } @@ -2333,128 +2337,855 @@ } -