Skip to content

Commit

Permalink
Change pagination logic for /artifacts/list endpoint (G-Research#243)
Browse files Browse the repository at this point in the history
  • Loading branch information
dsuhinin authored Aug 16, 2023
1 parent c412caa commit f7f8d0c
Show file tree
Hide file tree
Showing 10 changed files with 48 additions and 57 deletions.
1 change: 0 additions & 1 deletion pkg/api/mlflow/api/request/artifact.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package request
// ListArtifactsRequest is a request object for `GET /mlflow/artifacts/list` endpoint.
type ListArtifactsRequest struct {
Path string `query:"path"`
Token string `query:"token"`
RunID string `query:"run_id"`
RunUUID string `query:"run_uuid"`
}
Expand Down
12 changes: 5 additions & 7 deletions pkg/api/mlflow/api/response/artifact.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,17 @@ type FilePartialResponse struct {

// ListArtifactsResponse is a response object for `GET mlflow/artifacts/list` endpoint.
type ListArtifactsResponse struct {
Files []FilePartialResponse `json:"files"`
RootURI string `json:"root_uri"`
NextPageToken string `json:"next_page_token,omitempty"`
Files []FilePartialResponse `json:"files"`
RootURI string `json:"root_uri"`
}

// NewListArtifactsResponse creates new instance of ListArtifactsResponse.
func NewListArtifactsResponse(
nextPageToken, rootURI string, artifacts []storage.ArtifactObject,
rootURI string, artifacts []storage.ArtifactObject,
) *ListArtifactsResponse {
response := ListArtifactsResponse{
Files: make([]FilePartialResponse, len(artifacts)),
RootURI: rootURI,
NextPageToken: nextPageToken,
Files: make([]FilePartialResponse, len(artifacts)),
RootURI: rootURI,
}

for i, artifact := range artifacts {
Expand Down
5 changes: 2 additions & 3 deletions pkg/api/mlflow/api/response/artifact_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

func TestNewListArtifactsResponse_Ok(t *testing.T) {
response := NewListArtifactsResponse("pageToken", "rootUri", []storage.ArtifactObject{
response := NewListArtifactsResponse("rootUri", []storage.ArtifactObject{
{
Path: "path1",
Size: 1234567890,
Expand All @@ -35,7 +35,6 @@ func TestNewListArtifactsResponse_Ok(t *testing.T) {
FileSize: 0,
},
},
RootURI: "rootUri",
NextPageToken: "pageToken",
RootURI: "rootUri",
}, response)
}
4 changes: 2 additions & 2 deletions pkg/api/mlflow/controller/artifact.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ func (c Controller) ListArtifacts(ctx *fiber.Ctx) error {
}
log.Debugf("listArtifacts request: %#v", req)

nextPageToken, rootURI, artifacts, err := c.artifactService.ListArtifacts(ctx.Context(), &req)
rootURI, artifacts, err := c.artifactService.ListArtifacts(ctx.Context(), &req)
if err != nil {
return err
}

resp := response.NewListArtifactsResponse(nextPageToken, rootURI, artifacts)
resp := response.NewListArtifactsResponse(rootURI, artifacts)
log.Debugf("artifactList response: %#v", resp)
return ctx.JSON(resp)
}
14 changes: 7 additions & 7 deletions pkg/api/mlflow/service/artifact/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,22 @@ func NewService(artifactStorage storage.Provider, runRepository repositories.Run
// ListArtifacts handles business logic of `GET /artifacts/list` endpoint.
func (s Service) ListArtifacts(
ctx context.Context, req *request.ListArtifactsRequest,
) (string, string, []storage.ArtifactObject, error) {
) (string, []storage.ArtifactObject, error) {
if err := ValidateListArtifactsRequest(req); err != nil {
return "", "", nil, err
return "", nil, err
}

run, err := s.runRepository.GetByID(ctx, req.GetRunID())
if err != nil {
return "", "", nil, api.NewInternalError("unable to get artifact URI for run '%s'", req.GetRunID())
return "", nil, api.NewInternalError("unable to get artifact URI for run '%s'", req.GetRunID())
}

nextPageToken, rootURI, artifacts, err := s.artifactStorage.List(
run.ArtifactURI, req.Path, req.Token,
rootURI, artifacts, err := s.artifactStorage.List(
run.ArtifactURI, req.Path,
)
if err != nil {
return "", "", nil, api.NewInternalError("error getting artifact list from storage")
return "", nil, api.NewInternalError("error getting artifact list from storage")
}

return nextPageToken, rootURI, artifacts, nil
return rootURI, artifacts, nil
}
12 changes: 5 additions & 7 deletions pkg/api/mlflow/service/artifact/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ import (
func TestService_ListArtifacts_Ok(t *testing.T) {
artifactStorage := storage.MockProvider{}
artifactStorage.On(
"List", "/artifact/uri", "", "",
"List", "/artifact/uri", "",
).Return(
"nextPageToken",
"/root/uri/",
[]storage.ArtifactObject{
{
Expand Down Expand Up @@ -49,15 +48,14 @@ func TestService_ListArtifacts_Ok(t *testing.T) {

// call service under testing.
service := NewService(&artifactStorage, &runRepository)
nextPageToken, rootURI, artifacts, err := service.ListArtifacts(
rootURI, artifacts, err := service.ListArtifacts(
context.TODO(),
&request.ListArtifactsRequest{
RunID: "id",
},
)

assert.Nil(t, err)
assert.Equal(t, "nextPageToken", nextPageToken)
assert.Equal(t, "/root/uri/", rootURI)
assert.Equal(t, []storage.ArtifactObject{
{
Expand Down Expand Up @@ -133,9 +131,9 @@ func TestService_ListArtifacts_Error(t *testing.T) {
service: func() *Service {
artifactStorage := storage.MockProvider{}
artifactStorage.On(
"List", "/artifact/uri", "", "",
"List", "/artifact/uri", "",
).Return(
"", "", nil, errors.New("storage error"),
"", nil, errors.New("storage error"),
)

runRepository := repositories.MockRunRepositoryProvider{}
Expand All @@ -158,7 +156,7 @@ func TestService_ListArtifacts_Error(t *testing.T) {
for _, tt := range testData {
t.Run(tt.name, func(t *testing.T) {
// call service under testing.
_, _, _, err := tt.service().ListArtifacts(context.TODO(), tt.request)
_, _, err := tt.service().ListArtifacts(context.TODO(), tt.request)
assert.Equal(t, tt.error, err)
})
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/api/mlflow/service/artifact/storage/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,31 +24,31 @@ func NewLocal(config *config.ServiceConfig) (*Local, error) {
}

// List implements Provider interface.
func (s Local) List(artifactURI, path, _ string) (string, string, []ArtifactObject, error) {
func (s Local) List(artifactURI, path string) (string, []ArtifactObject, error) {
// 1. process search `prefix` parameter.
path, err := url.JoinPath(artifactURI, path)
if err != nil {
return "", "", nil, eris.Wrap(err, "error constructing full path")
return "", nil, eris.Wrap(err, "error constructing full path")
}

// 2. read data from local storage.
objects, err := os.ReadDir(path)
if err != nil {
return "", "", nil, eris.Wrapf(err, "error reading object from local storage")
return "", nil, eris.Wrapf(err, "error reading object from local storage")
}

log.Debugf("got %d objects from local storage for path: %s", len(objects), path)
artifactList := make([]ArtifactObject, len(objects))
for i, object := range objects {
info, err := object.Info()
if err != nil {
return "", "", nil, eris.Wrapf(err, "error getting info for object: %s", object.Name())
return "", nil, eris.Wrapf(err, "error getting info for object: %s", object.Name())
}
artifactList[i] = ArtifactObject{
Path: filepath.Join(path, info.Name()),
Size: info.Size(),
IsDir: object.IsDir(),
}
}
return "", s.config.ArtifactRoot, artifactList, nil
return s.config.ArtifactRoot, artifactList, nil
}
4 changes: 2 additions & 2 deletions pkg/api/mlflow/service/artifact/storage/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ func NewNoop() *Noop {
}

// List implements Provider interface.
func (s Noop) List(artifactURI, path, nextPageToken string) (string, string, []ArtifactObject, error) {
return "", "", make([]ArtifactObject, 0), nil
func (s Noop) List(_, _ string) (string, []ArtifactObject, error) {
return "", make([]ArtifactObject, 0), nil
}
41 changes: 19 additions & 22 deletions pkg/api/mlflow/service/artifact/storage/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ func NewS3(config *config.ServiceConfig) (*S3, error) {
}

// List implements Provider interface.
func (s S3) List(artifactURI, path, nextPageToken string) (string, string, []ArtifactObject, error) {
func (s S3) List(artifactURI, path string) (string, []ArtifactObject, error) {
bucket, prefix, err := ExtractS3BucketAndPrefix(artifactURI)
if err != nil {
return "", "", nil, eris.Wrap(err, "error extracting bucket and prefix from provided uri")
return "", nil, eris.Wrap(err, "error extracting bucket and prefix from provided uri")
}
input := s3.ListObjectsV2Input{
Bucket: aws.String(bucket),
Expand All @@ -69,33 +69,30 @@ func (s S3) List(artifactURI, path, nextPageToken string) (string, string, []Art
// 1. process search `prefix` parameter.
path, err = url.JoinPath(*input.Prefix, path)
if err != nil {
return "", "", nil, eris.Wrap(err, "error constructing s3 prefix")
return "", nil, eris.Wrap(err, "error constructing s3 prefix")
}
input.Prefix = aws.String(path)

// 2. process search `nextPageToken` parameter.
if nextPageToken != "" {
input.ContinuationToken = aws.String(nextPageToken)
}

output, err := s.client.ListObjectsV2(context.TODO(), &input)
paginator := s3.NewListObjectsV2Paginator(s.client, &input)
if err != nil {
return "", "", nil, eris.Wrap(err, "error getting s3 objects")
return "", nil, eris.Wrap(err, "error creating s3 paginated request")
}

log.Debugf("got %d objects from S3 storage for path: %s", len(output.Contents), path)
artifactList := make([]ArtifactObject, len(output.Contents))
for i, object := range output.Contents {
artifactList[i] = ArtifactObject{
Path: *object.Key,
Size: object.Size,
IsDir: false,
var artifactList []ArtifactObject
for paginator.HasMorePages() {
page, err := paginator.NextPage(context.TODO())
if err != nil {
return "", nil, eris.Wrap(err, "error getting s3 page objects")
}
log.Debugf("got %d objects from S3 storage for path: %s", len(page.Contents), path)
for _, object := range page.Contents {
artifactList = append(artifactList, ArtifactObject{
Path: *object.Key,
Size: object.Size,
IsDir: false,
})
}
}

if output.NextContinuationToken != nil {
return *output.NextContinuationToken, s.config.ArtifactRoot, artifactList, nil
}

return "", fmt.Sprintf("s3://%s", bucket), artifactList, nil
return fmt.Sprintf("s3://%s", bucket), artifactList, nil
}
2 changes: 1 addition & 1 deletion pkg/api/mlflow/service/artifact/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (o ArtifactObject) IsDirectory() bool {

// Provider provides and interface to work with artifact storage.
type Provider interface {
List(artifactURI, path, nextPageToken string) (string, string, []ArtifactObject, error)
List(artifactURI, path string) (string, []ArtifactObject, error)
}

// NewArtifactStorage creates new Artifact storage.
Expand Down

0 comments on commit f7f8d0c

Please sign in to comment.