Skip to content

Commit

Permalink
Merge pull request #1025 from traPtitech/feat/entrypoint_validate
Browse files Browse the repository at this point in the history
エントリーポイントの確認
  • Loading branch information
ikura-hamu authored Oct 27, 2024
2 parents 0daed89 + f7a0f1c commit 7c5b62e
Show file tree
Hide file tree
Showing 9 changed files with 451 additions and 207 deletions.
3 changes: 2 additions & 1 deletion docs/openapi/v2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ servers:
- url: /api/v2 #oapi-codegenでproxy後にもmatchできるようにするため必要
info:
description: 'traP Collection v2'
version: '2.3.0'
version: '2.4.0'
title: 'traP Collection v2'
contact:
name: traP
Expand Down Expand Up @@ -857,6 +857,7 @@ paths:
$ref: '#/components/schemas/Error'
description: |
リクエストが不正である場合に返されます。
エントリーポイントが存在しない、zipファイルでないなどです。
"401":
$ref: '#/components/responses/TraPUnauthorized'
"404":
Expand Down
8 changes: 7 additions & 1 deletion src/handler/v2/game_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,13 @@ func (gameFile GameFile) PostGameFile(c echo.Context, gameID openapi.GameIDInPat
var err error
savedFile, err = gameFile.gameFileService.SaveGameFile(c.Request().Context(), r, values.NewGameIDFromUUID(gameID), fileType, entryPoint)
if errors.Is(err, service.ErrInvalidGameID) {
return echo.NewHTTPError(http.StatusNotFound, "invalid gameID")
return echo.NewHTTPError(http.StatusNotFound, "gameID not found")
}
if errors.Is(err, service.ErrNotZipFile) {
return echo.NewHTTPError(http.StatusBadRequest, "only zip file is allowed")
}
if errors.Is(err, service.ErrInvalidEntryPoint) {
return echo.NewHTTPError(http.StatusBadRequest, "invalid entry point")
}
if err != nil {
log.Printf("error: failed to save game file: %v\n", err)
Expand Down
20 changes: 20 additions & 0 deletions src/handler/v2/game_file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,26 @@ func TestPostGameFile(t *testing.T) {
isErr: true,
statusCode: http.StatusNotFound,
},
{
description: "SaveGameFileがErrNotZipFileなので400",
fileType: openapi.Jar,
gameID: uuid.UUID(values.NewGameID()),
reader: bytes.NewReader([]byte("test")),
executeSaveGameFile: true,
saveGameFileErr: service.ErrNotZipFile,
isErr: true,
statusCode: http.StatusBadRequest,
},
{
description: "SaveGameFileがErrInvalidEntryPointなので400",
fileType: openapi.Jar,
gameID: uuid.UUID(values.NewGameID()),
reader: bytes.NewReader([]byte("test")),
executeSaveGameFile: true,
saveGameFileErr: service.ErrInvalidEntryPoint,
isErr: true,
statusCode: http.StatusBadRequest,
},
{
description: "SaveGameFileがエラーなので500",
fileType: openapi.Jar,
Expand Down
353 changes: 177 additions & 176 deletions src/handler/v2/openapi/openapi.gen.go

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/service/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,6 @@ var (
ErrDuplicateGameGenre = errors.New("duplicate game genre")
ErrDuplicateGameGenreName = errors.New("duplicate game genre name")
ErrInvalidGamesSortType = errors.New("invalid games sort type")
ErrNotZipFile = errors.New("not zip file")
ErrInvalidEntryPoint = errors.New("invalid entry point")
)
74 changes: 73 additions & 1 deletion src/service/v2/game_file.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package v2

import (
"archive/zip"
"context"
"errors"
"fmt"
"io"
"net/url"
"os"
"slices"
"time"

"github.com/traPtitech/trap-collection-server/src/domain"
Expand Down Expand Up @@ -37,7 +40,52 @@ func NewGameFile(
}
}

func (*GameFile) checkZip(_ context.Context, reader io.Reader) (zr *zip.Reader, ok bool, err error) {
f, err := os.CreateTemp("", "game_file")
if err != nil {
return nil, false, fmt.Errorf("failed to create temp file: %w", err)
}
defer func() {
err = os.Remove(f.Name())
}()
defer func() {
err = f.Close()
}()

_, err = io.Copy(f, reader)
if err != nil {
return nil, false, fmt.Errorf("failed to copy file: %w", err)
}

fInfo, err := f.Stat()
if err != nil {
return nil, false, fmt.Errorf("failed to get file info: %w", err)
}
zr, err = zip.NewReader(f, fInfo.Size())
if errors.Is(err, zip.ErrFormat) {
return nil, false, nil
}
if err != nil {
return nil, false, fmt.Errorf("failed to open zip file: %w", err)
}

return zr, true, nil
}

func (*GameFile) checkEntryPointExist(_ context.Context, zr *zip.Reader, entryPoint values.GameFileEntryPoint) (bool, error) {
entryPointExists := slices.ContainsFunc(zr.File, func(zf *zip.File) bool {
return zf.Name == string(entryPoint) && !zf.FileInfo().IsDir()
})

if !entryPointExists {
return false, nil
}

return true, nil
}

func (gameFile *GameFile) SaveGameFile(ctx context.Context, reader io.Reader, gameID values.GameID, fileType values.GameFileType, entryPoint values.GameFileEntryPoint) (*domain.GameFile, error) {

var file *domain.GameFile
err := gameFile.db.Transaction(ctx, nil, func(ctx context.Context) error {
_, err := gameFile.gameRepository.GetGame(ctx, gameID, repository.LockTypeRecord)
Expand All @@ -53,6 +101,7 @@ func (gameFile *GameFile) SaveGameFile(ctx context.Context, reader io.Reader, ga
eg, ctx := errgroup.WithContext(ctx)
hashPr, hashPw := io.Pipe()
filePr, filePw := io.Pipe()
entryPointPr, entryPointPw := io.Pipe()

eg.Go(func() error {
defer hashPr.Close()
Expand Down Expand Up @@ -89,11 +138,34 @@ func (gameFile *GameFile) SaveGameFile(ctx context.Context, reader io.Reader, ga
return nil
})

eg.Go(func() error {
defer entryPointPr.Close()

zr, ok, err := gameFile.checkZip(ctx, entryPointPr)
if err != nil {
return fmt.Errorf("failed to check zip: %w", err)
}
if !ok {
return service.ErrNotZipFile
}

ok, err = gameFile.checkEntryPointExist(ctx, zr, entryPoint)
if err != nil {
return fmt.Errorf("failed to check entry point exist: %w", err)
}
if !ok {
return service.ErrInvalidEntryPoint
}

return nil
})

eg.Go(func() error {
defer hashPw.Close()
defer filePw.Close()
defer entryPointPw.Close()

mw := io.MultiWriter(hashPw, filePw)
mw := io.MultiWriter(hashPw, filePw, entryPointPw)
_, err = io.Copy(mw, reader)
if err != nil {
return fmt.Errorf("failed to copy file: %w", err)
Expand Down
Loading

0 comments on commit 7c5b62e

Please sign in to comment.