Skip to content

Commit 74b0e81

Browse files
committed
feat: add cache for credential to reduce the probability that kpm would be considered a threat
Signed-off-by: zongz <zongzhe1024@163.com>
1 parent b65d5ed commit 74b0e81

File tree

8 files changed

+291
-40
lines changed

8 files changed

+291
-40
lines changed

pkg/client/client.go

Lines changed: 135 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package client
22

33
import (
4+
"context"
45
"encoding/json"
56
"fmt"
67
"io"
@@ -19,7 +20,9 @@ import (
1920
"github.com/otiai10/copy"
2021
"golang.org/x/mod/module"
2122
"kcl-lang.io/kcl-go/pkg/kcl"
23+
"oras.land/oras-go/pkg/auth"
2224
"oras.land/oras-go/v2"
25+
remoteauth "oras.land/oras-go/v2/registry/remote/auth"
2326

2427
"kcl-lang.io/kpm/pkg/constants"
2528
"kcl-lang.io/kpm/pkg/downloader"
@@ -41,6 +44,8 @@ type KpmClient struct {
4144
logWriter io.Writer
4245
// The downloader of the dependencies.
4346
DepDownloader *downloader.DepDownloader
47+
// credential store
48+
credsClient *downloader.CredClient
4449
// The home path of kpm for global configuration file and kcl package storage path.
4550
homePath string
4651
// The settings of kpm loaded from the global configuration file.
@@ -75,6 +80,33 @@ func (c *KpmClient) SetNoSumCheck(noSumCheck bool) {
7580
c.noSumCheck = noSumCheck
7681
}
7782

83+
// GetCredsClient will return the credential client.
84+
func (c *KpmClient) GetCredsClient() (*downloader.CredClient, error) {
85+
if c.credsClient == nil {
86+
credCli, err := downloader.LoadCredentialFile(c.settings.CredentialsFile)
87+
if err != nil {
88+
return nil, err
89+
}
90+
c.credsClient = credCli
91+
}
92+
return c.credsClient, nil
93+
}
94+
95+
// GetCredentials will return the credentials of the host.
96+
func (c *KpmClient) GetCredentials(hostName string) (*remoteauth.Credential, error) {
97+
credCli, err := c.GetCredsClient()
98+
if err != nil {
99+
return nil, err
100+
}
101+
102+
creds, err := credCli.Credential(hostName)
103+
if err != nil {
104+
return nil, err
105+
}
106+
107+
return creds, nil
108+
}
109+
78110
// GetNoSumCheck will return the 'noSumCheck' flag.
79111
func (c *KpmClient) GetNoSumCheck() bool {
80112
return c.noSumCheck
@@ -953,7 +985,18 @@ func (c *KpmClient) FillDependenciesInfo(modFile *pkg.ModFile) error {
953985

954986
// AcquireTheLatestOciVersion will acquire the latest version of the OCI reference.
955987
func (c *KpmClient) AcquireTheLatestOciVersion(ociSource downloader.Oci) (string, error) {
956-
ociClient, err := oci.NewOciClient(ociSource.Reg, ociSource.Repo, &c.settings)
988+
repoPath := utils.JoinPath(ociSource.Reg, ociSource.Repo)
989+
cred, err := c.GetCredentials(ociSource.Reg)
990+
if err != nil {
991+
return "", err
992+
}
993+
994+
ociClient, err := oci.NewOciClientWithOpts(
995+
oci.WithCredential(cred),
996+
oci.WithRepoPath(repoPath),
997+
oci.WithPlainHttp(c.GetSettings().DefaultOciPlainHttp()),
998+
)
999+
9571000
if err != nil {
9581001
return "", err
9591002
}
@@ -1098,11 +1141,16 @@ func (c *KpmClient) Download(dep *pkg.Dependency, homePath, localPath string) (*
10981141
// clean the temp dir.
10991142
defer os.RemoveAll(tmpDir)
11001143

1144+
credCli, err := c.GetCredsClient()
1145+
if err != nil {
1146+
return nil, err
1147+
}
11011148
err = c.DepDownloader.Download(*downloader.NewDownloadOptions(
11021149
downloader.WithLocalPath(tmpDir),
11031150
downloader.WithSource(dep.Source),
11041151
downloader.WithLogWriter(c.logWriter),
11051152
downloader.WithSettings(c.settings),
1153+
downloader.WithCredsClient(credCli),
11061154
))
11071155
if err != nil {
11081156
return nil, err
@@ -1276,10 +1324,22 @@ func (c *KpmClient) ParseKclModFile(kclPkg *pkg.KclPkg) (map[string]map[string]s
12761324

12771325
// LoadPkgFromOci will download the kcl package from the oci repository and return an `KclPkg`.
12781326
func (c *KpmClient) DownloadPkgFromOci(dep *downloader.Oci, localPath string) (*pkg.KclPkg, error) {
1279-
ociClient, err := oci.NewOciClient(dep.Reg, dep.Repo, &c.settings)
1327+
repoPath := utils.JoinPath(dep.Reg, dep.Repo)
1328+
cred, err := c.GetCredentials(dep.Reg)
12801329
if err != nil {
12811330
return nil, err
12821331
}
1332+
1333+
ociClient, err := oci.NewOciClientWithOpts(
1334+
oci.WithCredential(cred),
1335+
oci.WithRepoPath(repoPath),
1336+
oci.WithPlainHttp(c.GetSettings().DefaultOciPlainHttp()),
1337+
)
1338+
1339+
if err != nil {
1340+
return nil, err
1341+
}
1342+
12831343
ociClient.SetLogWriter(c.logWriter)
12841344
// Select the latest tag, if the tag, the user inputed, is empty.
12851345
var tagSelected string
@@ -1478,7 +1538,18 @@ func (c *KpmClient) PullFromOci(localPath, source, tag string) error {
14781538

14791539
// PushToOci will push a kcl package to oci registry.
14801540
func (c *KpmClient) PushToOci(localPath string, ociOpts *opt.OciOptions) error {
1481-
ociCli, err := oci.NewOciClient(ociOpts.Reg, ociOpts.Repo, &c.settings)
1541+
repoPath := utils.JoinPath(ociOpts.Reg, ociOpts.Repo)
1542+
cred, err := c.GetCredentials(ociOpts.Reg)
1543+
if err != nil {
1544+
return err
1545+
}
1546+
1547+
ociCli, err := oci.NewOciClientWithOpts(
1548+
oci.WithCredential(cred),
1549+
oci.WithRepoPath(repoPath),
1550+
oci.WithPlainHttp(c.GetSettings().DefaultOciPlainHttp()),
1551+
)
1552+
14821553
if err != nil {
14831554
return err
14841555
}
@@ -1504,12 +1575,46 @@ func (c *KpmClient) PushToOci(localPath string, ociOpts *opt.OciOptions) error {
15041575

15051576
// LoginOci will login to the oci registry.
15061577
func (c *KpmClient) LoginOci(hostname, username, password string) error {
1507-
return oci.Login(hostname, username, password, &c.settings)
1578+
1579+
credCli, err := c.GetCredsClient()
1580+
if err != nil {
1581+
return err
1582+
}
1583+
1584+
err = credCli.GetAuthClient().LoginWithOpts(
1585+
[]auth.LoginOption{
1586+
auth.WithLoginHostname(hostname),
1587+
auth.WithLoginUsername(username),
1588+
auth.WithLoginSecret(password),
1589+
}...,
1590+
)
1591+
1592+
if err != nil {
1593+
return reporter.NewErrorEvent(
1594+
reporter.FailedLogin,
1595+
err,
1596+
fmt.Sprintf("failed to login '%s', please check registry, username and password is valid", hostname),
1597+
)
1598+
}
1599+
1600+
return nil
15081601
}
15091602

15101603
// LogoutOci will logout from the oci registry.
15111604
func (c *KpmClient) LogoutOci(hostname string) error {
1512-
return oci.Logout(hostname, &c.settings)
1605+
1606+
credCli, err := c.GetCredsClient()
1607+
if err != nil {
1608+
return err
1609+
}
1610+
1611+
err = credCli.GetAuthClient().Logout(context.Background(), hostname)
1612+
1613+
if err != nil {
1614+
return reporter.NewErrorEvent(reporter.FailedLogout, err, fmt.Sprintf("failed to logout '%s'", hostname))
1615+
}
1616+
1617+
return nil
15131618
}
15141619

15151620
// ParseOciRef will parser '<repo_name>:<repo_tag>' into an 'OciOptions'.
@@ -1753,7 +1858,18 @@ func (c *KpmClient) pullTarFromOci(localPath string, ociOpts *opt.OciOptions) er
17531858
return reporter.NewErrorEvent(reporter.Bug, err)
17541859
}
17551860

1756-
ociCli, err := oci.NewOciClient(ociOpts.Reg, ociOpts.Repo, &c.settings)
1861+
repoPath := utils.JoinPath(ociOpts.Reg, ociOpts.Repo)
1862+
cred, err := c.GetCredentials(ociOpts.Reg)
1863+
if err != nil {
1864+
return err
1865+
}
1866+
1867+
ociCli, err := oci.NewOciClientWithOpts(
1868+
oci.WithCredential(cred),
1869+
oci.WithRepoPath(repoPath),
1870+
oci.WithPlainHttp(c.GetSettings().DefaultOciPlainHttp()),
1871+
)
1872+
17571873
if err != nil {
17581874
return err
17591875
}
@@ -1790,7 +1906,19 @@ func (c *KpmClient) pullTarFromOci(localPath string, ociOpts *opt.OciOptions) er
17901906

17911907
// FetchOciManifestConfIntoJsonStr will fetch the oci manifest config of the kcl package from the oci registry and return it into json string.
17921908
func (c *KpmClient) FetchOciManifestIntoJsonStr(opts opt.OciFetchOptions) (string, error) {
1793-
ociCli, err := oci.NewOciClient(opts.Reg, opts.Repo, &c.settings)
1909+
1910+
repoPath := utils.JoinPath(opts.Reg, opts.Repo)
1911+
cred, err := c.GetCredentials(opts.Reg)
1912+
if err != nil {
1913+
return "", err
1914+
}
1915+
1916+
ociCli, err := oci.NewOciClientWithOpts(
1917+
oci.WithCredential(cred),
1918+
oci.WithRepoPath(repoPath),
1919+
oci.WithPlainHttp(c.GetSettings().DefaultOciPlainHttp()),
1920+
)
1921+
17941922
if err != nil {
17951923
return "", err
17961924
}

pkg/client/visitor.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,18 @@ func (rv *RemoteVisitor) Visit(s *downloader.Source, v visitFunc) error {
150150
tmpDir = filepath.Join(tmpDir, constants.GitScheme)
151151
}
152152

153+
credCli, err := rv.kpmcli.GetCredsClient()
154+
if err != nil {
155+
return err
156+
}
157+
153158
defer os.RemoveAll(tmpDir)
154159
err = rv.kpmcli.DepDownloader.Download(*downloader.NewDownloadOptions(
155160
downloader.WithLocalPath(tmpDir),
156161
downloader.WithSource(*s),
157162
downloader.WithLogWriter(rv.kpmcli.GetLogWriter()),
158163
downloader.WithSettings(*rv.kpmcli.GetSettings()),
164+
downloader.WithCredsClient(credCli),
159165
))
160166

161167
if err != nil {

pkg/downloader/credential.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package downloader
2+
3+
import (
4+
"fmt"
5+
6+
dockerauth "oras.land/oras-go/pkg/auth/docker"
7+
remoteauth "oras.land/oras-go/v2/registry/remote/auth"
8+
)
9+
10+
// CredClient is the client to get the credentials.
11+
type CredClient struct {
12+
credsClient *dockerauth.Client
13+
}
14+
15+
// LoadCredentialFile loads the credential file and return the CredClient.
16+
func LoadCredentialFile(filepath string) (*CredClient, error) {
17+
authClient, err := dockerauth.NewClientWithDockerFallback(filepath)
18+
if err != nil {
19+
return nil, err
20+
}
21+
dockerAuthClient, ok := authClient.(*dockerauth.Client)
22+
if !ok {
23+
return nil, fmt.Errorf("authClient is not *docker.Client type")
24+
}
25+
26+
return &CredClient{
27+
credsClient: dockerAuthClient,
28+
}, nil
29+
}
30+
31+
// GetAuthClient returns the auth client.
32+
func (cred *CredClient) GetAuthClient() *dockerauth.Client {
33+
return cred.credsClient
34+
}
35+
36+
// Credential will reture the credential info cache in CredClient
37+
func (cred *CredClient) Credential(hostName string) (*remoteauth.Credential, error) {
38+
if len(hostName) == 0 {
39+
return nil, fmt.Errorf("hostName is empty")
40+
}
41+
username, password, err := cred.credsClient.Credential(hostName)
42+
if err != nil {
43+
return nil, err
44+
}
45+
46+
return &remoteauth.Credential{
47+
Username: username,
48+
Password: password,
49+
}, nil
50+
}

pkg/downloader/downloader.go

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"kcl-lang.io/kpm/pkg/reporter"
1414
"kcl-lang.io/kpm/pkg/settings"
1515
"kcl-lang.io/kpm/pkg/utils"
16+
remoteauth "oras.land/oras-go/v2/registry/remote/auth"
1617
)
1718

1819
// DownloadOptions is the options for downloading a package.
@@ -25,10 +26,18 @@ type DownloadOptions struct {
2526
Settings settings.Settings
2627
// LogWriter is the writer to write the log.
2728
LogWriter io.Writer
29+
// credsClient is the client to get the credentials.
30+
credsClient *CredClient
2831
}
2932

3033
type Option func(*DownloadOptions)
3134

35+
func WithCredsClient(credsClient *CredClient) Option {
36+
return func(do *DownloadOptions) {
37+
do.credsClient = credsClient
38+
}
39+
}
40+
3241
func WithLogWriter(logWriter io.Writer) Option {
3342
return func(do *DownloadOptions) {
3443
do.LogWriter = logWriter
@@ -125,7 +134,25 @@ func (d *OciDownloader) Download(opts DownloadOptions) error {
125134

126135
localPath := opts.LocalPath
127136

128-
ociCli, err := oci.NewOciClient(ociSource.Reg, ociSource.Repo, &opts.Settings)
137+
repoPath := utils.JoinPath(ociSource.Reg, ociSource.Repo)
138+
139+
var cred *remoteauth.Credential
140+
var err error
141+
if opts.credsClient != nil {
142+
cred, err = opts.credsClient.Credential(ociSource.Reg)
143+
if err != nil {
144+
return err
145+
}
146+
} else {
147+
cred = &remoteauth.Credential{}
148+
}
149+
150+
ociCli, err := oci.NewOciClientWithOpts(
151+
oci.WithCredential(cred),
152+
oci.WithRepoPath(repoPath),
153+
oci.WithPlainHttp(opts.Settings.DefaultOciPlainHttp()),
154+
)
155+
129156
if err != nil {
130157
return err
131158
}

0 commit comments

Comments
 (0)