diff --git a/pkg/download/downloader.go b/pkg/download/downloader.go index 3d20c9a8a..5da0e70eb 100644 --- a/pkg/download/downloader.go +++ b/pkg/download/downloader.go @@ -225,7 +225,10 @@ func (d *Downloader) Resolve(req *base.Request) (rr *ResolveResult, err error) { return } - res := d.triggerOnResolve(req) + res, err := d.triggerOnResolve(req) + if err != nil { + return + } if res != nil && len(res.Files) > 0 { rr = &ResolveResult{ Res: res, @@ -525,9 +528,6 @@ func (d *Downloader) emit(eventKey EventKey, task *Task, errs ...error) { Task: task, Err: err, }) - if eventKey == EventKeyError { - d.emit(EventKeyFinally, task, err) - } } } @@ -572,27 +572,36 @@ func (d *Downloader) getProtocolConfig(name string, v any) bool { func (d *Downloader) watch(task *Task) { err := task.fetcher.Wait() if err != nil { - task.updateStatus(base.DownloadStatusError) - d.storage.Put(bucketTask, task.ID, task.clone()) + d.doOnError(task, err) + return + } + task.Progress.Used = task.timer.Used() + if task.Meta.Res.Size == 0 { + task.Meta.Res.Size = task.fetcher.Progress().TotalDownloaded() + } + used := task.Progress.Used / int64(time.Second) + if used == 0 { + used = 1 + } + totalSize := task.Meta.Res.Size + task.Progress.Speed = totalSize / used + task.Progress.Downloaded = totalSize + task.updateStatus(base.DownloadStatusDone) + d.storage.Put(bucketTask, task.ID, task.clone()) + d.emit(EventKeyDone, task) + d.emit(EventKeyFinally, task, err) + d.notifyRunning() +} + +func (d *Downloader) doOnError(task *Task, err error) { + d.Logger.Warn().Err(err).Msgf("task download failed, task id: %s", task.ID) + task.updateStatus(base.DownloadStatusError) + d.triggerOnError(task, err) + if task.Status == base.DownloadStatusError { d.emit(EventKeyError, task, err) - } else { - task.Progress.Used = task.timer.Used() - if task.Meta.Res.Size == 0 { - task.Meta.Res.Size = task.fetcher.Progress().TotalDownloaded() - } - used := task.Progress.Used / int64(time.Second) - if used == 0 { - used = 1 - } - totalSize := task.Meta.Res.Size - task.Progress.Speed = totalSize / used - task.Progress.Downloaded = totalSize - task.updateStatus(base.DownloadStatusDone) - d.storage.Put(bucketTask, task.ID, task.clone()) - d.emit(EventKeyDone, task) d.emit(EventKeyFinally, task, err) + d.notifyRunning() } - d.notifyRunning() } func (d *Downloader) restoreFetcher(task *Task) error { @@ -697,26 +706,18 @@ func (d *Downloader) doStart(task *Task) (err error) { } } - cloneTask := task.clone() isCreate := task.Status == base.DownloadStatusReady - task.updateStatus(base.DownloadStatusRunning) doStart := func() error { task.lock.Lock() defer task.lock.Unlock() - req := d.triggerOnStart(cloneTask) - if req != nil { - task.Meta.Req = req - task.fetcher.Meta().Req = req - } + d.triggerOnStart(task) + task.updateStatus(base.DownloadStatusRunning) if task.Meta.Res == nil { err := task.fetcher.Resolve(task.Meta.Req) if err != nil { - task.updateStatus(base.DownloadStatusError) - d.storage.Put(bucketTask, task.ID, task.clone()) - d.emit(EventKeyError, task, err) return err } task.Meta.Res = task.fetcher.Meta().Res @@ -762,7 +763,7 @@ func (d *Downloader) doStart(task *Task) (err error) { go func() { err := doStart() if err != nil { - d.Logger.Error().Stack().Err(err).Msgf("start task failed, task id: %s", task.ID) + d.doOnError(task, err) } }() diff --git a/pkg/download/engine/engine.go b/pkg/download/engine/engine.go index 903b57f5b..969a5cefe 100644 --- a/pkg/download/engine/engine.go +++ b/pkg/download/engine/engine.go @@ -3,6 +3,7 @@ package engine import ( _ "embed" "errors" + gojaerror "github.com/GopeedLab/gopeed/pkg/download/engine/inject/error" "github.com/GopeedLab/gopeed/pkg/download/engine/inject/file" "github.com/GopeedLab/gopeed/pkg/download/engine/inject/formdata" "github.com/GopeedLab/gopeed/pkg/download/engine/inject/vm" @@ -125,6 +126,9 @@ func NewEngine(cfg *Config) *Engine { runtime.SetFieldNameMapper(goja.TagFieldNameMapper("json", true)) vm.Enable(runtime) gojaurl.Enable(runtime) + if err := gojaerror.Enable(runtime); err != nil { + return + } if err := file.Enable(runtime); err != nil { return } @@ -167,7 +171,11 @@ func resolveResult(value goja.Value) (any, error) { case goja.PromiseStateFulfilled: return p.Result().Export(), nil case goja.PromiseStateRejected: - return nil, errors.New(p.Result().String()) + if err, ok := p.Result().Export().(error); ok { + return nil, err + } else { + return nil, errors.New(p.Result().String()) + } } } return export, nil diff --git a/pkg/download/engine/engine_test.go b/pkg/download/engine/engine_test.go index fbab1a128..63e0281eb 100644 --- a/pkg/download/engine/engine_test.go +++ b/pkg/download/engine/engine_test.go @@ -7,7 +7,9 @@ import ( "errors" "fmt" "github.com/GopeedLab/gopeed/internal/test" + gojaerror "github.com/GopeedLab/gopeed/pkg/download/engine/inject/error" "github.com/GopeedLab/gopeed/pkg/download/engine/inject/file" + gojautil "github.com/GopeedLab/gopeed/pkg/download/engine/util" "github.com/GopeedLab/gopeed/pkg/util" "github.com/dop251/goja" "io" @@ -20,6 +22,7 @@ import ( ) func TestPolyfill(t *testing.T) { + doTestPolyfill(t, "MessageError") doTestPolyfill(t, "XMLHttpRequest") doTestPolyfill(t, "Blob") doTestPolyfill(t, "FormData") @@ -27,6 +30,16 @@ func TestPolyfill(t *testing.T) { doTestPolyfill(t, "__gopeed_create_vm") } +func TestError(t *testing.T) { + engine := NewEngine(nil) + _, err := engine.RunString(` + throw new MessageError('test'); + `) + if me, ok := gojautil.AssertError[*gojaerror.MessageError](err); !ok { + t.Fatalf("expect MessageError, but got %v", me) + } +} + func TestFetch(t *testing.T) { server := startServer() defer server.Close() diff --git a/pkg/download/engine/inject/error/module.go b/pkg/download/engine/inject/error/module.go new file mode 100644 index 000000000..761d61f68 --- /dev/null +++ b/pkg/download/engine/inject/error/module.go @@ -0,0 +1,29 @@ +package error + +import ( + "github.com/dop251/goja" +) + +type MessageError struct { + Message string `json:"message"` +} + +func (e *MessageError) Error() string { + return e.Message +} + +func Enable(runtime *goja.Runtime) error { + messageError := runtime.ToValue(func(call goja.ConstructorCall) *goja.Object { + var message string + if len(call.Arguments) > 0 { + message = call.Arguments[0].String() + } + instance := &MessageError{ + Message: message, + } + instanceValue := runtime.ToValue(instance).(*goja.Object) + instanceValue.SetPrototype(call.This.Prototype()) + return instanceValue + }) + return runtime.Set("MessageError", messageError) +} diff --git a/pkg/download/engine/inject/util.go b/pkg/download/engine/inject/util.go deleted file mode 100644 index d00504be1..000000000 --- a/pkg/download/engine/inject/util.go +++ /dev/null @@ -1,9 +0,0 @@ -package inject - -import ( - "github.com/dop251/goja" -) - -func ThrowTypeError(vm *goja.Runtime, msg string) { - panic(vm.NewTypeError(msg)) -} diff --git a/pkg/download/engine/inject/xhr/module.go b/pkg/download/engine/inject/xhr/module.go index 8be6d4862..d91c56ac0 100644 --- a/pkg/download/engine/inject/xhr/module.go +++ b/pkg/download/engine/inject/xhr/module.go @@ -2,9 +2,9 @@ package xhr import ( "bytes" - "github.com/GopeedLab/gopeed/pkg/download/engine/inject" "github.com/GopeedLab/gopeed/pkg/download/engine/inject/file" "github.com/GopeedLab/gopeed/pkg/download/engine/inject/formdata" + "github.com/GopeedLab/gopeed/pkg/download/engine/util" "github.com/dop251/goja" "io" "mime/multipart" @@ -316,7 +316,7 @@ func (xhr *XMLHttpRequest) parseData(data goja.Value) any { func Enable(runtime *goja.Runtime, proxyUrl *url.URL) error { progressEvent := runtime.ToValue(func(call goja.ConstructorCall) *goja.Object { if len(call.Arguments) < 1 { - inject.ThrowTypeError(runtime, "Failed to construct 'ProgressEvent': 1 argument required, but only 0 present.") + util.ThrowTypeError(runtime, "Failed to construct 'ProgressEvent': 1 argument required, but only 0 present.") } instance := &ProgressEvent{ Type: call.Argument(0).String(), diff --git a/pkg/download/engine/util/util.go b/pkg/download/engine/util/util.go new file mode 100644 index 000000000..c490c367d --- /dev/null +++ b/pkg/download/engine/util/util.go @@ -0,0 +1,24 @@ +package util + +import ( + "github.com/dop251/goja" +) + +func ThrowTypeError(vm *goja.Runtime, msg string) { + panic(vm.NewTypeError(msg)) +} + +func AssertError[T error](err error) (t T, r bool) { + if err == nil { + return + } + if e, ok := err.(T); ok { + return e, true + } + if e, ok := err.(*goja.Exception); ok { + if ee, okk := e.Value().Export().(T); okk { + return ee, true + } + } + return +} diff --git a/pkg/download/extension.go b/pkg/download/extension.go index d2a766f6d..71218987a 100644 --- a/pkg/download/extension.go +++ b/pkg/download/extension.go @@ -6,6 +6,8 @@ import ( "github.com/GopeedLab/gopeed/internal/logger" "github.com/GopeedLab/gopeed/pkg/base" "github.com/GopeedLab/gopeed/pkg/download/engine" + gojaerror "github.com/GopeedLab/gopeed/pkg/download/engine/inject/error" + gojautil "github.com/GopeedLab/gopeed/pkg/download/engine/util" "github.com/GopeedLab/gopeed/pkg/util" "github.com/dop251/goja" "github.com/go-git/go-git/v5" @@ -35,7 +37,7 @@ type ActivationEvent string const ( EventOnResolve ActivationEvent = "onResolve" EventOnStart ActivationEvent = "onStart" - //EventOnError ActivationEvent = "onError" + EventOnError ActivationEvent = "onError" //EventOnDone ActivationEvent = "onDone" ) @@ -238,8 +240,8 @@ func (d *Downloader) parseExtensionByPath(path string) (*Extension, error) { return &ext, nil } -func (d *Downloader) triggerOnResolve(req *base.Request) (res *base.Resource) { - doTrigger(d, +func (d *Downloader) triggerOnResolve(req *base.Request) (res *base.Resource, err error) { + err = doTrigger(d, EventOnResolve, req, &OnResolveContext{ @@ -264,12 +266,12 @@ func (d *Downloader) triggerOnResolve(req *base.Request) (res *base.Resource) { return } -func (d *Downloader) triggerOnStart(task *Task) (req *base.Request) { +func (d *Downloader) triggerOnStart(task *Task) { doTrigger(d, EventOnStart, task.Meta.Req, &OnStartContext{ - Task: task, + Task: NewExtensionTask(d, task), }, func(ext *Extension, gopeed *Instance, ctx *OnStartContext) { // Validate request structure @@ -278,14 +280,25 @@ func (d *Downloader) triggerOnStart(task *Task) (req *base.Request) { gopeed.Logger.logger.Warn().Err(err).Msgf("[%s] request invalid", ext.buildIdentity()) return } - req = ctx.Task.Meta.Req } }, ) return } -func doTrigger[T any](d *Downloader, event ActivationEvent, req *base.Request, ctx T, handler func(ext *Extension, gopeed *Instance, ctx T)) { +func (d *Downloader) triggerOnError(task *Task, err error) { + doTrigger(d, + EventOnError, + task.Meta.Req, + &OnErrorContext{ + Task: NewExtensionTask(d, task), + Error: err, + }, + nil, + ) +} + +func doTrigger[T any](d *Downloader, event ActivationEvent, req *base.Request, ctx T, handler func(ext *Extension, gopeed *Instance, ctx T)) error { // init extension global object gopeed := &Instance{ Events: make(InstanceEvents), @@ -347,13 +360,20 @@ func doTrigger[T any](d *Downloader, event ActivationEvent, req *base.Request, c gopeed.Logger.logger.Error().Err(err).Msgf("[%s] call function failed: %s", ext.buildIdentity(), event) return } - handler(ext, gopeed, ctx) + if handler != nil { + handler(ext, gopeed, ctx) + } } }() } } } - return + + // Only return MessageError + if me, ok := gojautil.AssertError[*gojaerror.MessageError](err); ok { + return me + } + return nil } func (d *Downloader) ExtensionPath(ext *Extension) string { @@ -581,10 +601,10 @@ func (h InstanceEvents) OnStart(fn goja.Callable) { h.register(EventOnStart, fn) } -//func (h InstanceEvents) OnError(fn goja.Callable) { -// h.register(HookEventOnError, fn) -//} -// +func (h InstanceEvents) OnError(fn goja.Callable) { + h.register(EventOnError, fn) +} + //func (h InstanceEvents) OnDone(fn goja.Callable) { // h.register(HookEventOnDone, fn) //} @@ -653,7 +673,33 @@ type OnResolveContext struct { } type OnStartContext struct { - Task *Task `json:"task"` + Task *ExtensionTask `json:"task"` +} + +type OnErrorContext struct { + Task *ExtensionTask `json:"task"` + Error error `json:"error"` +} + +type ExtensionTask struct { + *Task + + download *Downloader +} + +func NewExtensionTask(download *Downloader, task *Task) *ExtensionTask { + return &ExtensionTask{ + Task: task.clone(), + download: download, + } +} + +func (t *ExtensionTask) Continue() error { + return t.download.Continue(t.ID) +} + +func (t *ExtensionTask) Pause() error { + return t.download.Pause(t.ID) } func parseSettings(settings []*Setting) map[string]any { diff --git a/pkg/download/extension_test.go b/pkg/download/extension_test.go index 92d4357b5..59707722f 100644 --- a/pkg/download/extension_test.go +++ b/pkg/download/extension_test.go @@ -4,6 +4,7 @@ import ( "errors" "github.com/GopeedLab/gopeed/internal/logger" "github.com/GopeedLab/gopeed/pkg/base" + gojaerror "github.com/GopeedLab/gopeed/pkg/download/engine/inject/error" "github.com/dop251/goja" "os" "testing" @@ -165,7 +166,7 @@ func TestDownloader_UpgradeExtension(t *testing.T) { }) } -func TestDownloader_ExtensionByOnStart(t *testing.T) { +func TestDownloader_Extension_OnStart(t *testing.T) { downloadAndCheck := func(req *base.Request) { setupDownloader(func(downloader *Downloader) { if _, err := downloader.InstallExtensionByFolder("./testdata/extensions/on_start", false); err != nil { @@ -192,10 +193,10 @@ func TestDownloader_ExtensionByOnStart(t *testing.T) { } task := downloader.GetTask(id) if task.Meta.Req.URL != "https://github.com" { - panic("extension on start modify url error") + t.Fatalf("except url: https://github.com, actual: %s", task.Meta.Req.URL) } if task.Meta.Req.Labels["modified"] != "true" { - panic("extension on start modify label error") + t.Fatalf("except label: modified=true, actual: %s", task.Meta.Req.Labels["modified"]) } }) } @@ -207,13 +208,54 @@ func TestDownloader_ExtensionByOnStart(t *testing.T) { // label match downloadAndCheck(&base.Request{ - URL: "https://xxx.com", + URL: "https://test.com", Labels: map[string]string{ "test": "true", }, }) } +func TestDownloader_Extension_OnError(t *testing.T) { + setupDownloader(func(downloader *Downloader) { + if _, err := downloader.InstallExtensionByFolder("./testdata/extensions/on_error", false); err != nil { + t.Fatal(err) + } + errCh := make(chan error, 1) + downloader.Listener(func(event *Event) { + if event.Key == EventKeyFinally { + errCh <- event.Err + } + }) + id, err := downloader.CreateDirect(&base.Request{ + URL: "https://github.com/gopeed/test/404", + Labels: map[string]string{ + "test": "true", + }, + }, nil) + if err != nil { + t.Fatal(err) + } + select { + case err = <-errCh: + break + case <-time.After(time.Second * 10): + err = errors.New("timeout") + } + + if err != nil { + panic("extension on error download error: " + err.Error()) + } + // extension on error modify url and continue download + task := downloader.GetTask(id) + if task.Meta.Req.URL != "https://github.com" { + t.Fatalf("except url: https://github.com, actual: %s", task.Meta.Req.URL) + } + if task.Status != base.DownloadStatusDone { + t.Fatalf("except status is done, actual: %s", task.Status) + } + }) +} + func TestDownloader_Extension_Errors(t *testing.T) { setupDownloader(func(downloader *Downloader) { if _, err := downloader.InstallExtensionByFolder("./testdata/extensions/script_error", false); err != nil { @@ -244,6 +286,26 @@ func TestDownloader_Extension_Errors(t *testing.T) { t.Fatal("function error catch failed") } }) + + setupDownloader(func(downloader *Downloader) { + if _, err := downloader.InstallExtensionByFolder("./testdata/extensions/message_error", false); err != nil { + t.Fatal(err) + } + _, err := downloader.Resolve(&base.Request{ + URL: "https://github.com/test", + }) + if err == nil { + t.Fatalf("except error, but got nil") + } + me, ok := err.(*gojaerror.MessageError) + if !ok { + t.Fatalf("except MessageError type, but got %s", err) + } + want := "test" + if me.Error() != want { + t.Fatalf("except MessageError message %s, but got %s", want, me.Message) + } + }) } func TestDownloader_Extension_Settings(t *testing.T) { @@ -336,22 +398,6 @@ func TestDownloader_DeleteExtension(t *testing.T) { }) } -func TestDownloader_Extension_OnResolve(t *testing.T) { - setupDownloader(func(downloader *Downloader) { - installedExt, err := downloader.InstallExtensionByFolder("./testdata/extensions/settings_all", false) - if err != nil { - t.Fatal(err) - } - if err := downloader.DeleteExtension(installedExt.Identity); err != nil { - t.Fatal(err) - } - extensions := downloader.GetExtensions() - if len(extensions) != 0 { - t.Fatal("extension delete fail") - } - }) -} - func TestDownloader_Extension_Logger(t *testing.T) { logger := logger.NewLogger(false, "") il := newInstanceLogger(&Extension{ diff --git a/pkg/download/testdata/extensions/function_error/index.js b/pkg/download/testdata/extensions/function_error/index.js index e2c5ab1e3..bf7191ba4 100644 --- a/pkg/download/testdata/extensions/function_error/index.js +++ b/pkg/download/testdata/extensions/function_error/index.js @@ -1,5 +1,6 @@ gopeed.events.onResolve(async function (ctx) { const aaa = {}; + // access undefined property gopeed.logger.info(aaa.bbb.ccc); ctx.res = { diff --git a/pkg/download/testdata/extensions/message_error/index.js b/pkg/download/testdata/extensions/message_error/index.js new file mode 100644 index 000000000..07b42a9dd --- /dev/null +++ b/pkg/download/testdata/extensions/message_error/index.js @@ -0,0 +1,3 @@ +gopeed.events.onResolve(async function (ctx) { + throw new MessageError("test"); +}); diff --git a/pkg/download/testdata/extensions/message_error/manifest.json b/pkg/download/testdata/extensions/message_error/manifest.json new file mode 100644 index 000000000..dae42ead6 --- /dev/null +++ b/pkg/download/testdata/extensions/message_error/manifest.json @@ -0,0 +1,16 @@ +{ + "name": "message-error", + "title": "gopeed extension message error test", + "version": "0.0.1", + "scripts": [ + { + "event": "onResolve", + "match": { + "urls": [ + "*://github.com/*" + ] + }, + "entry": "index.js" + } + ] +} \ No newline at end of file diff --git a/pkg/download/testdata/extensions/on_error/index.js b/pkg/download/testdata/extensions/on_error/index.js new file mode 100644 index 000000000..23853d847 --- /dev/null +++ b/pkg/download/testdata/extensions/on_error/index.js @@ -0,0 +1,8 @@ +gopeed.events.onError(async function (ctx) { + gopeed.logger.info("url", ctx.task.meta.req.url); + gopeed.logger.info("error", ctx.error); + ctx.task.meta.req.url = "https://github.com"; + ctx.task.pause(); + ctx.task.continue(); +}); + diff --git a/pkg/download/testdata/extensions/on_error/manifest.json b/pkg/download/testdata/extensions/on_error/manifest.json new file mode 100644 index 000000000..15fff5a59 --- /dev/null +++ b/pkg/download/testdata/extensions/on_error/manifest.json @@ -0,0 +1,16 @@ +{ + "name": "on-error", + "title": "gopeed extension on error event test", + "version": "0.0.1", + "scripts": [ + { + "event": "onError", + "match": { + "labels": [ + "test" + ] + }, + "entry": "index.js" + } + ] +} \ No newline at end of file diff --git a/pkg/download/testdata/extensions/on_start/index.js b/pkg/download/testdata/extensions/on_start/index.js index cc6a6d5f6..e41109a55 100644 --- a/pkg/download/testdata/extensions/on_start/index.js +++ b/pkg/download/testdata/extensions/on_start/index.js @@ -1,4 +1,5 @@ gopeed.events.onStart(async function (ctx) { + gopeed.logger.info("url", ctx.task.meta.req.url); ctx.task.meta.req.url = "https://github.com"; ctx.task.meta.req.labels['modified'] = 'true'; }); diff --git a/pkg/rest/server_test.go b/pkg/rest/server_test.go index bc99bef2a..676b6fd40 100644 --- a/pkg/rest/server_test.go +++ b/pkg/rest/server_test.go @@ -204,12 +204,14 @@ func TestPauseAllAndContinueALLTasks(t *testing.T) { // continue all httpRequestCheckOk[any](http.MethodPut, "/api/v1/tasks/continue", nil) + time.Sleep(time.Millisecond * 100) tasks := httpRequestCheckOk[[]*download.Task](http.MethodGet, fmt.Sprintf("/api/v1/tasks?status=%s", base.DownloadStatusRunning), nil) if len(tasks) != cfg.MaxRunning { t.Errorf("ContinueAllTasks() got = %v, want %v", len(tasks), cfg.MaxRunning) } // pause all httpRequestCheckOk[any](http.MethodPut, "/api/v1/tasks/pause", nil) + time.Sleep(time.Millisecond * 100) tasks = httpRequestCheckOk[[]*download.Task](http.MethodGet, fmt.Sprintf("/api/v1/tasks?status=%s", base.DownloadStatusPause), nil) if len(tasks) != total { t.Errorf("PauseAllTasks() got = %v, want %v", len(tasks), total)