From 94b65604cc2042955799c48497e21d08ab1fcc0b Mon Sep 17 00:00:00 2001 From: Arran Ubels Date: Thu, 29 Jan 2026 14:34:36 +1100 Subject: [PATCH 1/2] refactor: extract HTTP routing and handlers into dedicated router files. --- cmd/gobookmarks/main_test.go | 63 ----- cmd/gobookmarks/serve.go | 444 +++-------------------------------- router.go | 155 ++++++++++++ router_helpers.go | 309 ++++++++++++++++++++++++ router_helpers_test.go | 67 ++++++ 5 files changed, 562 insertions(+), 476 deletions(-) create mode 100644 router.go create mode 100644 router_helpers.go create mode 100644 router_helpers_test.go diff --git a/cmd/gobookmarks/main_test.go b/cmd/gobookmarks/main_test.go index 88e00a9..f29446b 100644 --- a/cmd/gobookmarks/main_test.go +++ b/cmd/gobookmarks/main_test.go @@ -1,72 +1,9 @@ package main import ( - "context" - "errors" - "net/http" - "net/http/httptest" - "strings" "testing" - - gb "github.com/arran4/gobookmarks" - "github.com/gorilla/sessions" - "golang.org/x/oauth2" ) -func TestRunHandlerChain_UserErrorRedirect(t *testing.T) { - gb.SessionName = "testsess" - gb.SessionStore = sessions.NewCookieStore([]byte("secret")) - - req := httptest.NewRequest("GET", "/submit", nil) - req.Header.Set("Referer", "/form") - ctx := context.WithValue(req.Context(), gb.ContextValues("coreData"), &gb.CoreData{}) - req = req.WithContext(ctx) - - h := runHandlerChain(func(w http.ResponseWriter, r *http.Request) error { - return gb.NewUserError("bad input", errors.New("invalid")) - }) - - w := httptest.NewRecorder() - h(w, req) - res := w.Result() - if res.StatusCode != http.StatusSeeOther { - t.Fatalf("expected redirect, got %d", res.StatusCode) - } - loc := res.Header.Get("Location") - if !strings.Contains(loc, "error=bad+input") { - t.Fatalf("redirect missing error param: %s", loc) - } -} - -func TestRunTemplate_BufferedError(t *testing.T) { - gb.SessionName = "testsess" - gb.SessionStore = sessions.NewCookieStore([]byte("secret")) - gb.DBConnectionProvider = "" - - req := httptest.NewRequest("GET", "/", nil) - sess, _ := gb.SessionStore.New(req, gb.SessionName) - sess.Values["GithubUser"] = &gb.User{Login: "user"} - sess.Values["Token"] = &oauth2.Token{} - ctx := context.WithValue(req.Context(), gb.ContextValues("session"), sess) - ctx = context.WithValue(ctx, gb.ContextValues("provider"), "sql") - ctx = context.WithValue(ctx, gb.ContextValues("coreData"), &gb.CoreData{UserRef: "user"}) - req = req.WithContext(ctx) - - w := httptest.NewRecorder() - runTemplate("mainPage.gohtml")(w, req) - - body := w.Body.String() - if !strings.Contains(body, "Database error") { - t.Fatalf("expected database error message, got %q", body) - } - if strings.Count(body, "") != 1 { - t.Fatalf("unexpected partial content: %q", body) - } - if strings.Contains(body, "tab-list") { - t.Fatalf("unexpected partial page content: %q", body) - } -} - func TestLoadConfigUsesExternalURL(t *testing.T) { rc := NewRootCommand() diff --git a/cmd/gobookmarks/serve.go b/cmd/gobookmarks/serve.go index abbedb4..e3853e3 100644 --- a/cmd/gobookmarks/serve.go +++ b/cmd/gobookmarks/serve.go @@ -3,62 +3,51 @@ package main import ( "bytes" "context" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/x509" - "crypto/x509/pkix" "encoding/json" - "encoding/pem" "errors" "flag" "fmt" - . "github.com/arran4/gobookmarks" - "github.com/arran4/gorillamuxlogic" - "github.com/gorilla/mux" - "github.com/gorilla/securecookie" - "github.com/gorilla/sessions" - "io" "log" - "math/big" "net/http" - "net/url" "os" "os/signal" "path/filepath" - "reflect" "strconv" "strings" "sync" "time" + + . "github.com/arran4/gobookmarks" + "github.com/gorilla/securecookie" + "github.com/gorilla/sessions" ) type ServeCommand struct { parent Command Flags *flag.FlagSet - GithubClientID stringFlag - GithubSecret stringFlag - GitlabClientID stringFlag - GitlabSecret stringFlag - ExternalURL stringFlag - Namespace stringFlag - Title stringFlag + GithubClientID stringFlag + GithubSecret stringFlag + GitlabClientID stringFlag + GitlabSecret stringFlag + ExternalURL stringFlag + Namespace stringFlag + Title stringFlag FaviconCacheDir stringFlag FaviconCacheSize stringFlag FaviconMaxCacheCount stringFlag CommitsPerPage stringFlag GithubServer stringFlag - GitlabServer stringFlag - LocalGitPath stringFlag - DbProvider stringFlag - DbConn stringFlag - SessionKey stringFlag - ProviderOrder stringFlag - CssColumns boolFlag - NoFooter boolFlag - DevMode boolFlag - DumpConfig boolFlag + GitlabServer stringFlag + LocalGitPath stringFlag + DbProvider stringFlag + DbConn stringFlag + SessionKey stringFlag + ProviderOrder stringFlag + CssColumns boolFlag + NoFooter boolFlag + DevMode boolFlag + DumpConfig boolFlag } func (rc *RootCommand) NewServeCommand() (*ServeCommand, error) { @@ -262,103 +251,16 @@ func (c *ServeCommand) Execute(args []string) error { return errors.New("no providers available") } - r := mux.NewRouter() - - r.Use(UserAdderMiddleware) - r.Use(CoreAdderMiddleware) - - r.HandleFunc("/main.css", func(writer http.ResponseWriter, request *http.Request) { - _, _ = writer.Write(GetMainCSSData()) - }).Methods("GET") - r.HandleFunc("/favicon.ico", func(writer http.ResponseWriter, request *http.Request) { - _, _ = writer.Write(GetFavicon()) - }).Methods("GET") - - // Development helpers to toggle layout mode - if DevMode { - r.HandleFunc("/_css", runHandlerChain(EnableCssColumnsAction, redirectToHandler("/"))).Methods("GET") - r.HandleFunc("/_table", runHandlerChain(DisableCssColumnsAction, redirectToHandler("/"))).Methods("GET") - } - - // News - r.Handle("/", http.HandlerFunc(runTemplate("mainPage.gohtml"))).Methods("GET") - r.Handle("/tab", http.HandlerFunc(runTemplate("mainPage.gohtml"))).Methods("GET") - r.Handle("/tab/{tab}", http.HandlerFunc(runTemplate("mainPage.gohtml"))).Methods("GET") - r.HandleFunc("/", runHandlerChain(TaskDoneAutoRefreshPage)).Methods("POST") - - r.HandleFunc("/edit", runTemplate("loginPage.gohtml")).Methods("GET").MatcherFunc(gorillamuxlogic.Not(RequiresAnAccount())) - r.HandleFunc("/edit", runTemplate("edit.gohtml")).Methods("GET").MatcherFunc(RequiresAnAccount()) - r.HandleFunc("/edit", runTemplate("edit.gohtml")).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(HasError()) - r.HandleFunc("/edit", runHandlerChain(BookmarksEditSaveAction, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSave)) - r.HandleFunc("/edit", runHandlerChain(BookmarksEditSaveAction, TaskDoneAutoRefreshPage)).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndDone)) - r.HandleFunc("/edit", runHandlerChain(BookmarksEditSaveAction, StopEditMode, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndStopEditing)) - r.HandleFunc("/edit", runHandlerChain(BookmarksEditCreateAction, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher("Create")) - r.HandleFunc("/edit", runHandlerChain(TaskDoneAutoRefreshPage)).Methods("POST") - - r.HandleFunc("/startEditMode", runHandlerChain(StartEditMode, redirectToHandlerTabPage("/"))).Methods("POST", "GET").MatcherFunc(RequiresAnAccount()) - r.HandleFunc("/stopEditMode", runHandlerChain(StopEditMode, redirectToHandlerTabPage("/"))).Methods("POST", "GET").MatcherFunc(RequiresAnAccount()) - - r.HandleFunc("/editCategory", runTemplate("loginPage.gohtml")).Methods("GET").MatcherFunc(gorillamuxlogic.Not(RequiresAnAccount())) - r.HandleFunc("/editCategory", runHandlerChain(EditCategoryPage)).Methods("GET").MatcherFunc(RequiresAnAccount()) - r.HandleFunc("/editCategory", runHandlerChain(CategoryEditSaveAction, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSave)) - r.HandleFunc("/editCategory", runHandlerChain(CategoryEditSaveAction, TaskDoneAutoRefreshPage)).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndDone)) - r.HandleFunc("/editCategory", runHandlerChain(CategoryEditSaveAction, StopEditMode, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndStopEditing)) - r.HandleFunc("/editCategory", runHandlerChain(TaskDoneAutoRefreshPage)).Methods("POST") - r.HandleFunc("/addCategory", runTemplate("loginPage.gohtml")).Methods("GET").MatcherFunc(gorillamuxlogic.Not(RequiresAnAccount())) - r.HandleFunc("/addCategory", runHandlerChain(AddCategoryPage)).Methods("GET").MatcherFunc(RequiresAnAccount()) - r.HandleFunc("/addCategory", runHandlerChain(CategoryAddSaveAction, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSave)) - r.HandleFunc("/addCategory", runHandlerChain(CategoryAddSaveAction, TaskDoneAutoRefreshPage)).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndDone)) - r.HandleFunc("/addCategory", runHandlerChain(CategoryAddSaveAction, StopEditMode, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndStopEditing)) - r.HandleFunc("/addCategory", runHandlerChain(TaskDoneAutoRefreshPage)).Methods("POST") - r.HandleFunc("/moveCategory", runHandlerChain(CategoryMoveBeforeAction)).Methods("POST").MatcherFunc(RequiresAnAccount()) - r.HandleFunc("/moveCategoryEnd", runHandlerChain(CategoryMoveEndAction)).Methods("POST").MatcherFunc(RequiresAnAccount()) - r.HandleFunc("/moveCategoryNewColumn", runHandlerChain(CategoryMoveNewColumnAction)).Methods("POST").MatcherFunc(RequiresAnAccount()) - - r.HandleFunc("/editTab", runTemplate("loginPage.gohtml")).Methods("GET").MatcherFunc(gorillamuxlogic.Not(RequiresAnAccount())) - r.HandleFunc("/editTab", runHandlerChain(EditTabPage)).Methods("GET").MatcherFunc(RequiresAnAccount()) - r.HandleFunc("/editTab", runHandlerChain(TabEditSaveAction, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSave)) - r.HandleFunc("/editTab", runHandlerChain(TabEditSaveAction, TaskDoneAutoRefreshPage)).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndDone)) - r.HandleFunc("/editTab", runHandlerChain(TabEditSaveAction, StopEditMode, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndStopEditing)) - r.HandleFunc("/editTab", runHandlerChain(TaskDoneAutoRefreshPage)).Methods("POST") - r.HandleFunc("/tab/{tab}/edit", runTemplate("loginPage.gohtml")).Methods("GET").MatcherFunc(gorillamuxlogic.Not(RequiresAnAccount())) - r.HandleFunc("/tab/{tab}/edit", runHandlerChain(EditTabPage)).Methods("GET").MatcherFunc(RequiresAnAccount()) - r.HandleFunc("/tab/{tab}/edit", runHandlerChain(TabEditSaveAction, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSave)) - r.HandleFunc("/tab/{tab}/edit", runHandlerChain(TabEditSaveAction, TaskDoneAutoRefreshPage)).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndDone)) - r.HandleFunc("/tab/{tab}/edit", runHandlerChain(TabEditSaveAction, StopEditMode, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndStopEditing)) - r.HandleFunc("/tab/{tab}/edit", runHandlerChain(TaskDoneAutoRefreshPage)).Methods("POST") - - r.HandleFunc("/editPage", runTemplate("loginPage.gohtml")).Methods("GET").MatcherFunc(gorillamuxlogic.Not(RequiresAnAccount())) - r.HandleFunc("/editPage", runHandlerChain(EditPagePage)).Methods("GET").MatcherFunc(RequiresAnAccount()) - r.HandleFunc("/editPage", runHandlerChain(PageEditSaveAction, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSave)) - r.HandleFunc("/editPage", runHandlerChain(PageEditSaveAction, TaskDoneAutoRefreshPage)).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndDone)) - r.HandleFunc("/editPage", runHandlerChain(PageEditSaveAction, StopEditMode, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndStopEditing)) - r.HandleFunc("/editPage", runHandlerChain(TaskDoneAutoRefreshPage)).Methods("POST") - - r.HandleFunc("/moveTab", runHandlerChain(MoveTabAction)).Methods("POST").MatcherFunc(RequiresAnAccount()) - r.HandleFunc("/movePage", runHandlerChain(MovePageAction)).Methods("POST").MatcherFunc(RequiresAnAccount()) - r.HandleFunc("/tab/{tab}/movePage", runHandlerChain(MovePageAction)).Methods("POST").MatcherFunc(RequiresAnAccount()) - r.HandleFunc("/moveEntry", runHandlerChain(MoveEntryAction)).Methods("POST").MatcherFunc(RequiresAnAccount()) - r.HandleFunc("/tab/{tab}/moveEntry", runHandlerChain(MoveEntryAction)).Methods("POST").MatcherFunc(RequiresAnAccount()) - - r.HandleFunc("/history", runTemplate("loginPage.gohtml")).Methods("GET").MatcherFunc(gorillamuxlogic.Not(RequiresAnAccount())) - r.HandleFunc("/history", runTemplate("history.gohtml")).Methods("GET").MatcherFunc(RequiresAnAccount()) - - r.HandleFunc("/history/commits", runTemplate("loginPage.gohtml")).Methods("GET").MatcherFunc(gorillamuxlogic.Not(RequiresAnAccount())) - r.HandleFunc("/status", runTemplate("statusPage.gohtml")).Methods("GET") - r.HandleFunc("/history/commits", runTemplate("historyCommits.gohtml")).Methods("GET").MatcherFunc(RequiresAnAccount()) - - r.HandleFunc("/login", runTemplate("loginPage.gohtml")).Methods("GET") - r.HandleFunc("/login/git", runTemplate("gitLoginPage.gohtml")).Methods("GET") - r.HandleFunc("/login/git", runHandlerChain(GitLoginAction, redirectToHandler("/"))).Methods("POST") - r.HandleFunc("/signup/git", runHandlerChain(GitSignupAction, redirectToHandler("/login/git"))).Methods("POST") - r.HandleFunc("/login/sql", runTemplate("sqlLoginPage.gohtml")).Methods("GET") - r.HandleFunc("/login/sql", runHandlerChain(SqlLoginAction, redirectToHandler("/"))).Methods("POST") - r.HandleFunc("/signup/sql", runHandlerChain(SqlSignupAction, redirectToHandler("/login/sql"))).Methods("POST") - r.HandleFunc("/login/{provider}", runHandlerChain(LoginWithProvider)).Methods("GET") - r.HandleFunc("/logout", runHandlerChain(UserLogoutAction, runTemplate("logoutPage.gohtml"))).Methods("GET") - r.HandleFunc("/oauth2Callback", runHandlerChain(Oauth2CallbackPage, redirectToHandler("/"))).Methods("GET") - - r.HandleFunc("/proxy/favicon", FaviconProxyHandler).Methods("GET") + // Create RouterConfig + routerCfg := &RouterConfig{ + SessionStore: SessionStore, // Globals should be initialized by now (lines 248-257) + SessionName: SessionName, + ExternalURL: cfg.ExternalURL, + BaseURL: "", // Root + DevMode: *cfg.DevMode, + } + + r := NewRouter(routerCfg) http.Handle("/", r) @@ -453,290 +355,6 @@ func splitList(s string) []string { return out } -func CreatePEMFiles() { - notBefore := time.Now() - notAfter := notBefore.Add(365 * 24 * time.Hour) // Valid for 1 year - - serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) - if err != nil { - log.Fatalf("Failed to generate serial number: %v", err) - } - - template := x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{ - Organization: []string{"Your Organization"}, - }, - NotBefore: notBefore, - NotAfter: notAfter, - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - } - - priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) - if err != nil { - log.Fatalf("Failed to generate private key: %v", err) - } - - derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) - if err != nil { - log.Fatalf("Failed to create certificate: %v", err) - } - - certFile, err := os.Create("cert.pem") - if err != nil { - log.Fatalf("Failed to create cert.pem file: %v", err) - } - defer certFile.Close() - if err := pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { - log.Fatalf("Failed to write data to cert.pem: %v", err) - } - - keyFile, err := os.Create("key.pem") - if err != nil { - log.Fatalf("Failed to create key.pem file: %v", err) - } - defer keyFile.Close() - privBytes, err := x509.MarshalECPrivateKey(priv) - if err != nil { - log.Fatalf("Failed to marshal private key: %v", err) - } - if err := pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes}); err != nil { - log.Fatalf("Failed to write data to key.pem: %v", err) - } -} - -func runHandlerChain(chain ...any) func(http.ResponseWriter, *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - for _, each := range chain { - switch each := each.(type) { - case http.Handler: - each.ServeHTTP(w, r) - case http.HandlerFunc: - each(w, r) - case func(http.ResponseWriter, *http.Request): - each(w, r) - case func(http.ResponseWriter, *http.Request) error: - if err := each(w, r); err != nil { - if errors.Is(err, ErrHandled) { - return - } - if errors.Is(err, ErrSignedOut) { - if logoutErr := UserLogoutAction(w, r); logoutErr != nil { - log.Printf("logout error: %v", logoutErr) - } - type Data struct{ *CoreData } - if err := GetCompiledTemplates(NewFuncs(r)).ExecuteTemplate(w, "logoutPage.gohtml", Data{r.Context().Value(ContextValues("coreData")).(*CoreData)}); err != nil { - log.Printf("Logout Template Error: %s", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - } - return - } - - var uerr UserError - if errors.As(err, &uerr) { - dest := r.Referer() - if dest == "" { - dest = r.URL.Path - if q := r.URL.Query(); len(q) > 0 { - dest += "?" + q.Encode() - } - } - u, parseErr := url.Parse(dest) - if parseErr != nil { - log.Printf("user error parse referer: %v", parseErr) - } else { - q := u.Query() - q.Set("error", uerr.Msg) - u.RawQuery = q.Encode() - http.Redirect(w, r, u.String(), http.StatusSeeOther) - return - } - } - - var serr SystemError - display := "Internal error" - if errors.As(err, &serr) { - display = serr.Msg - err = serr.Err - } - - log.Printf("handler error: %v", err) - - type ErrorData struct { - *CoreData - Error string - } - if err := GetCompiledTemplates(NewFuncs(r)).ExecuteTemplate(w, "error.gohtml", ErrorData{ - CoreData: r.Context().Value(ContextValues("coreData")).(*CoreData), - Error: display, - }); err != nil { - log.Printf("Error Template Error: %s", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - } - return - } - default: - log.Panicf("unknown input: %s", reflect.TypeOf(each)) - } - } - } -} - -func runTemplate(tmpl string) func(http.ResponseWriter, *http.Request) { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - type Data struct { - *CoreData - Error string - } - - data := Data{ - CoreData: r.Context().Value(ContextValues("coreData")).(*CoreData), - Error: r.URL.Query().Get("error"), - } - - var buf bytes.Buffer - err := GetCompiledTemplates(NewFuncs(r)).ExecuteTemplate(&buf, tmpl, data) - if err == nil { - _, _ = io.Copy(w, &buf) - return - } - - if errors.Is(err, ErrSignedOut) { - if logoutErr := UserLogoutAction(w, r); logoutErr != nil { - log.Printf("logout error: %v", logoutErr) - } - type LogoutData struct{ *CoreData } - if tplErr := GetCompiledTemplates(NewFuncs(r)).ExecuteTemplate(w, "logoutPage.gohtml", LogoutData{data.CoreData}); tplErr != nil { - log.Printf("Logout Template Error: %v", tplErr) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - } - return - } - - var serr SystemError - display := "Internal error" - if errors.As(err, &serr) { - display = serr.Msg - err = serr.Err - } - - log.Printf("Template %s error: %v", tmpl, err) - - type ErrorData struct { - *CoreData - Error string - } - - if tplErr := GetCompiledTemplates(NewFuncs(r)).ExecuteTemplate(w, "error.gohtml", ErrorData{ - CoreData: data.CoreData, - Error: display, - }); tplErr != nil { - log.Printf("Error Template Error: %v", tplErr) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - } - }) -} - -func redirectToHandler(toUrl string) func(http.ResponseWriter, *http.Request) { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, toUrl, http.StatusTemporaryRedirect) - }) -} - -func redirectToHandlerBranchToRef(toUrl string) func(http.ResponseWriter, *http.Request) { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - u, _ := url.Parse(toUrl) - qs := u.Query() - qs.Set("ref", "refs/heads/"+r.PostFormValue("branch")) - tab := TabFromRequest(r) - if v, ok := r.Context().Value(ContextValues("redirectTab")).(string); ok { - if parsed, err := strconv.Atoi(v); err == nil { - tab = parsed - } - } - u.Path = TabPath(tab) - page := r.PostFormValue("page") - if v, ok := r.Context().Value(ContextValues("redirectPage")).(string); ok { - page = v - } - if fragment := PageFragmentFromIndex(page); fragment != "" { - u.Fragment = fragment - } - if edit := r.URL.Query().Get("edit"); edit != "" { - qs.Set("edit", edit) - } - u.RawQuery = qs.Encode() - http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect) - }) -} - -func redirectToHandlerTabPage(toUrl string) func(http.ResponseWriter, *http.Request) { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - u, _ := url.Parse(toUrl) - qs := u.Query() - u.Path = TabPath(TabFromRequest(r)) - if fragment := PageFragmentFromIndex(r.URL.Query().Get("page")); fragment != "" { - u.Fragment = fragment - } - if edit := r.URL.Query().Get("edit"); edit != "" { - qs.Set("edit", edit) - } - u.RawQuery = qs.Encode() - http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect) - }) -} - -func RequiresAnAccount() mux.MatcherFunc { - return func(request *http.Request, match *mux.RouteMatch) bool { - var session *sessions.Session - sessioni := request.Context().Value(ContextValues("session")) - if sessioni == nil { - var err error - session, err = SessionStore.Get(request, SessionName) - if err != nil { - return false - } - } else { - var ok bool - session, ok = sessioni.(*sessions.Session) - if !ok { - return false - } - } - if v, ok := session.Values["version"].(string); !ok || v != version { - return false - } - githubUser, ok := session.Values["GithubUser"].(*User) - return ok && githubUser != nil - } -} - -func TaskMatcher(taskName string) mux.MatcherFunc { - return func(request *http.Request, match *mux.RouteMatch) bool { - return request.PostFormValue("task") == taskName - } -} - -func ModeMatcher(modeName string) mux.MatcherFunc { - return func(request *http.Request, match *mux.RouteMatch) bool { - return request.URL.Query().Get("mode") == modeName - } -} - -func HasError() mux.MatcherFunc { - return func(request *http.Request, match *mux.RouteMatch) bool { - return request.URL.Query().Has("error") - } -} - -func NoTask() mux.MatcherFunc { - return func(request *http.Request, match *mux.RouteMatch) bool { - return request.PostFormValue("task") == "" - } -} - func fileExists(filename string) bool { _, err := os.Stat(filename) return !os.IsNotExist(err) diff --git a/router.go b/router.go new file mode 100644 index 0000000..30e36bf --- /dev/null +++ b/router.go @@ -0,0 +1,155 @@ +package gobookmarks + +import ( + "database/sql" + "net/http" + + "github.com/arran4/gorillamuxlogic" + "github.com/gorilla/mux" + "github.com/gorilla/sessions" +) + +// RouterConfig holds the dependencies and configuration for the gobookmarks application +type RouterConfig struct { + // DB is the database connection + DB *sql.DB + + // UserProvider handles user authentication and lookup + UserProvider UserProvider + + // SessionStore handles session management + SessionStore sessions.Store + // SessionName is the name of the session cookie + SessionName string + + // BaseURL is the prefix for all routes (e.g. "/bookmarks") + BaseURL string + + // ExternalURL is the public facing URL + ExternalURL string + + // DevMode enables development features + DevMode bool +} + +// UserProvider defines the interface for external user management +type UserProvider interface { + // CurrentUser returns the current user from the request context + CurrentUser(r *http.Request) (*User, error) + // IsLoggedIn checks if a user is logged in + IsLoggedIn(r *http.Request) bool +} + +// NewRouter creates a new router with the given configuration +func NewRouter(cfg *RouterConfig) http.Handler { + // Initialize globals temporarily until full refactor + if cfg.ExternalURL != "" { + // ExternalUrl is a global in gobookmarks package, need to verify + } + // Note: We are not handling all globals yet, focusing on router structure first + + r := mux.NewRouter() + + r.Use(UserAdderMiddleware) // Middleware needs to be adapted to use UserProvider if present + r.Use(CoreAdderMiddleware) + + r.HandleFunc("/main.css", func(writer http.ResponseWriter, request *http.Request) { + _, _ = writer.Write(GetMainCSSData()) + }).Methods("GET") + r.HandleFunc("/favicon.ico", func(writer http.ResponseWriter, request *http.Request) { + _, _ = writer.Write(GetFavicon()) + }).Methods("GET") + + // Development helpers to toggle layout mode + if cfg.DevMode { + r.HandleFunc("/_css", runHandlerChain(EnableCssColumnsAction, redirectToHandler("/"))).Methods("GET") + r.HandleFunc("/_table", runHandlerChain(DisableCssColumnsAction, redirectToHandler("/"))).Methods("GET") + } + + // News + r.Handle("/", http.HandlerFunc(runTemplate("mainPage.gohtml"))).Methods("GET") + r.Handle("/tab", http.HandlerFunc(runTemplate("mainPage.gohtml"))).Methods("GET") + r.Handle("/tab/{tab}", http.HandlerFunc(runTemplate("mainPage.gohtml"))).Methods("GET") + r.HandleFunc("/", runHandlerChain(TaskDoneAutoRefreshPage)).Methods("POST") + + r.HandleFunc("/edit", runTemplate("loginPage.gohtml")).Methods("GET").MatcherFunc(gorillamuxlogic.Not(RequiresAnAccount())) + r.HandleFunc("/edit", runTemplate("edit.gohtml")).Methods("GET").MatcherFunc(RequiresAnAccount()) + r.HandleFunc("/edit", runTemplate("edit.gohtml")).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(HasError()) + r.HandleFunc("/edit", runHandlerChain(BookmarksEditSaveAction, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSave)) + r.HandleFunc("/edit", runHandlerChain(BookmarksEditSaveAction, TaskDoneAutoRefreshPage)).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndDone)) + r.HandleFunc("/edit", runHandlerChain(BookmarksEditSaveAction, StopEditMode, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndStopEditing)) + r.HandleFunc("/edit", runHandlerChain(BookmarksEditCreateAction, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher("Create")) + r.HandleFunc("/edit", runHandlerChain(TaskDoneAutoRefreshPage)).Methods("POST") + + r.HandleFunc("/startEditMode", runHandlerChain(StartEditMode, redirectToHandlerTabPage("/"))).Methods("POST", "GET").MatcherFunc(RequiresAnAccount()) + r.HandleFunc("/stopEditMode", runHandlerChain(StopEditMode, redirectToHandlerTabPage("/"))).Methods("POST", "GET").MatcherFunc(RequiresAnAccount()) + + r.HandleFunc("/editCategory", runTemplate("loginPage.gohtml")).Methods("GET").MatcherFunc(gorillamuxlogic.Not(RequiresAnAccount())) + r.HandleFunc("/editCategory", runHandlerChain(EditCategoryPage)).Methods("GET").MatcherFunc(RequiresAnAccount()) + r.HandleFunc("/editCategory", runHandlerChain(CategoryEditSaveAction, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSave)) + r.HandleFunc("/editCategory", runHandlerChain(CategoryEditSaveAction, TaskDoneAutoRefreshPage)).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndDone)) + r.HandleFunc("/editCategory", runHandlerChain(CategoryEditSaveAction, StopEditMode, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndStopEditing)) + r.HandleFunc("/editCategory", runHandlerChain(TaskDoneAutoRefreshPage)).Methods("POST") + r.HandleFunc("/addCategory", runTemplate("loginPage.gohtml")).Methods("GET").MatcherFunc(gorillamuxlogic.Not(RequiresAnAccount())) + r.HandleFunc("/addCategory", runHandlerChain(AddCategoryPage)).Methods("GET").MatcherFunc(RequiresAnAccount()) + r.HandleFunc("/addCategory", runHandlerChain(CategoryAddSaveAction, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSave)) + r.HandleFunc("/addCategory", runHandlerChain(CategoryAddSaveAction, TaskDoneAutoRefreshPage)).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndDone)) + r.HandleFunc("/addCategory", runHandlerChain(CategoryAddSaveAction, StopEditMode, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndStopEditing)) + r.HandleFunc("/addCategory", runHandlerChain(TaskDoneAutoRefreshPage)).Methods("POST") + r.HandleFunc("/moveCategory", runHandlerChain(CategoryMoveBeforeAction)).Methods("POST").MatcherFunc(RequiresAnAccount()) + r.HandleFunc("/moveCategoryEnd", runHandlerChain(CategoryMoveEndAction)).Methods("POST").MatcherFunc(RequiresAnAccount()) + r.HandleFunc("/moveCategoryNewColumn", runHandlerChain(CategoryMoveNewColumnAction)).Methods("POST").MatcherFunc(RequiresAnAccount()) + + r.HandleFunc("/editTab", runTemplate("loginPage.gohtml")).Methods("GET").MatcherFunc(gorillamuxlogic.Not(RequiresAnAccount())) + r.HandleFunc("/editTab", runHandlerChain(EditTabPage)).Methods("GET").MatcherFunc(RequiresAnAccount()) + r.HandleFunc("/editTab", runHandlerChain(TabEditSaveAction, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSave)) + r.HandleFunc("/editTab", runHandlerChain(TabEditSaveAction, TaskDoneAutoRefreshPage)).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndDone)) + r.HandleFunc("/editTab", runHandlerChain(TabEditSaveAction, StopEditMode, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndStopEditing)) + r.HandleFunc("/editTab", runHandlerChain(TaskDoneAutoRefreshPage)).Methods("POST") + r.HandleFunc("/tab/{tab}/edit", runTemplate("loginPage.gohtml")).Methods("GET").MatcherFunc(gorillamuxlogic.Not(RequiresAnAccount())) + r.HandleFunc("/tab/{tab}/edit", runHandlerChain(EditTabPage)).Methods("GET").MatcherFunc(RequiresAnAccount()) + r.HandleFunc("/tab/{tab}/edit", runHandlerChain(TabEditSaveAction, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSave)) + r.HandleFunc("/tab/{tab}/edit", runHandlerChain(TabEditSaveAction, TaskDoneAutoRefreshPage)).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndDone)) + r.HandleFunc("/tab/{tab}/edit", runHandlerChain(TabEditSaveAction, StopEditMode, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndStopEditing)) + r.HandleFunc("/tab/{tab}/edit", runHandlerChain(TaskDoneAutoRefreshPage)).Methods("POST") + + r.HandleFunc("/editPage", runTemplate("loginPage.gohtml")).Methods("GET").MatcherFunc(gorillamuxlogic.Not(RequiresAnAccount())) + r.HandleFunc("/editPage", runHandlerChain(EditPagePage)).Methods("GET").MatcherFunc(RequiresAnAccount()) + r.HandleFunc("/editPage", runHandlerChain(PageEditSaveAction, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSave)) + r.HandleFunc("/editPage", runHandlerChain(PageEditSaveAction, TaskDoneAutoRefreshPage)).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndDone)) + r.HandleFunc("/editPage", runHandlerChain(PageEditSaveAction, StopEditMode, redirectToHandlerBranchToRef("/"))).Methods("POST").MatcherFunc(RequiresAnAccount()).MatcherFunc(TaskMatcher(TaskSaveAndStopEditing)) + r.HandleFunc("/editPage", runHandlerChain(TaskDoneAutoRefreshPage)).Methods("POST") + + r.HandleFunc("/moveTab", runHandlerChain(MoveTabAction)).Methods("POST").MatcherFunc(RequiresAnAccount()) + r.HandleFunc("/movePage", runHandlerChain(MovePageAction)).Methods("POST").MatcherFunc(RequiresAnAccount()) + r.HandleFunc("/tab/{tab}/movePage", runHandlerChain(MovePageAction)).Methods("POST").MatcherFunc(RequiresAnAccount()) + r.HandleFunc("/moveEntry", runHandlerChain(MoveEntryAction)).Methods("POST").MatcherFunc(RequiresAnAccount()) + r.HandleFunc("/tab/{tab}/moveEntry", runHandlerChain(MoveEntryAction)).Methods("POST").MatcherFunc(RequiresAnAccount()) + + r.HandleFunc("/history", runTemplate("loginPage.gohtml")).Methods("GET").MatcherFunc(gorillamuxlogic.Not(RequiresAnAccount())) + r.HandleFunc("/history", runTemplate("history.gohtml")).Methods("GET").MatcherFunc(RequiresAnAccount()) + + r.HandleFunc("/history/commits", runTemplate("loginPage.gohtml")).Methods("GET").MatcherFunc(gorillamuxlogic.Not(RequiresAnAccount())) + r.HandleFunc("/status", runTemplate("statusPage.gohtml")).Methods("GET") + r.HandleFunc("/history/commits", runTemplate("historyCommits.gohtml")).Methods("GET").MatcherFunc(RequiresAnAccount()) + + r.HandleFunc("/login", runTemplate("loginPage.gohtml")).Methods("GET") + r.HandleFunc("/login/git", runTemplate("gitLoginPage.gohtml")).Methods("GET") + r.HandleFunc("/login/git", runHandlerChain(GitLoginAction, redirectToHandler("/"))).Methods("POST") + r.HandleFunc("/signup/git", runHandlerChain(GitSignupAction, redirectToHandler("/login/git"))).Methods("POST") + r.HandleFunc("/login/sql", runTemplate("sqlLoginPage.gohtml")).Methods("GET") + r.HandleFunc("/login/sql", runHandlerChain(SqlLoginAction, redirectToHandler("/"))).Methods("POST") + r.HandleFunc("/signup/sql", runHandlerChain(SqlSignupAction, redirectToHandler("/login/sql"))).Methods("POST") + r.HandleFunc("/login/{provider}", runHandlerChain(LoginWithProvider)).Methods("GET") + r.HandleFunc("/logout", runHandlerChain(UserLogoutAction, runTemplate("logoutPage.gohtml"))).Methods("GET") + r.HandleFunc("/oauth2Callback", runHandlerChain(Oauth2CallbackPage, redirectToHandler("/"))).Methods("GET") + + r.HandleFunc("/proxy/favicon", FaviconProxyHandler).Methods("GET") + + // Create required files if missing + if !fileExists("cert.pem") || !fileExists("key.pem") { + CreatePEMFiles() + } + + return r +} diff --git a/router_helpers.go b/router_helpers.go new file mode 100644 index 0000000..713aec9 --- /dev/null +++ b/router_helpers.go @@ -0,0 +1,309 @@ +package gobookmarks + +import ( + "bytes" + "errors" + "io" + "log" + "net/http" + "net/url" + "os" + "reflect" + "strconv" + + "github.com/gorilla/mux" + "github.com/gorilla/sessions" + + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "time" +) + +func runHandlerChain(chain ...any) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + for _, each := range chain { + switch each := each.(type) { + case http.HandlerFunc: + each(w, r) + case http.Handler: + each.ServeHTTP(w, r) + case func(http.ResponseWriter, *http.Request): + each(w, r) + case func(http.ResponseWriter, *http.Request) error: + if err := each(w, r); err != nil { + if errors.Is(err, ErrHandled) { + return + } + if errors.Is(err, ErrSignedOut) { + if logoutErr := UserLogoutAction(w, r); logoutErr != nil { + log.Printf("logout error: %v", logoutErr) + } + type Data struct{ *CoreData } + if err := GetCompiledTemplates(NewFuncs(r)).ExecuteTemplate(w, "logoutPage.gohtml", Data{r.Context().Value(ContextValues("coreData")).(*CoreData)}); err != nil { + log.Printf("Logout Template Error: %s", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + return + } + + var uerr UserError + if errors.As(err, &uerr) { + dest := r.Referer() + if dest == "" { + dest = r.URL.Path + if q := r.URL.Query(); len(q) > 0 { + dest += "?" + q.Encode() + } + } + u, parseErr := url.Parse(dest) + if parseErr != nil { + log.Printf("user error parse referer: %v", parseErr) + } else { + q := u.Query() + q.Set("error", uerr.Msg) + u.RawQuery = q.Encode() + http.Redirect(w, r, u.String(), http.StatusSeeOther) + return + } + } + + var serr SystemError + display := "Internal error" + if errors.As(err, &serr) { + display = serr.Msg + err = serr.Err + } + + log.Printf("handler error: %v", err) + + type ErrorData struct { + *CoreData + Error string + } + if err := GetCompiledTemplates(NewFuncs(r)).ExecuteTemplate(w, "error.gohtml", ErrorData{ + CoreData: r.Context().Value(ContextValues("coreData")).(*CoreData), + Error: display, + }); err != nil { + log.Printf("Error Template Error: %s", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + return + } + default: + log.Panicf("unknown input: %s", reflect.TypeOf(each)) + } + } + } +} + +func runTemplate(tmpl string) func(http.ResponseWriter, *http.Request) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + type Data struct { + *CoreData + Error string + } + + data := Data{ + CoreData: r.Context().Value(ContextValues("coreData")).(*CoreData), + Error: r.URL.Query().Get("error"), + } + + var buf bytes.Buffer + err := GetCompiledTemplates(NewFuncs(r)).ExecuteTemplate(&buf, tmpl, data) + if err == nil { + _, _ = io.Copy(w, &buf) + return + } + + if errors.Is(err, ErrSignedOut) { + if logoutErr := UserLogoutAction(w, r); logoutErr != nil { + log.Printf("logout error: %v", logoutErr) + } + type LogoutData struct{ *CoreData } + if tplErr := GetCompiledTemplates(NewFuncs(r)).ExecuteTemplate(w, "logoutPage.gohtml", LogoutData{data.CoreData}); tplErr != nil { + log.Printf("Logout Template Error: %v", tplErr) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + return + } + + var serr SystemError + display := "Internal error" + if errors.As(err, &serr) { + display = serr.Msg + err = serr.Err + } + + log.Printf("Template %s error: %v", tmpl, err) + + type ErrorData struct { + *CoreData + Error string + } + + if tplErr := GetCompiledTemplates(NewFuncs(r)).ExecuteTemplate(w, "error.gohtml", ErrorData{ + CoreData: data.CoreData, + Error: display, + }); tplErr != nil { + log.Printf("Error Template Error: %v", tplErr) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + }) +} + +func redirectToHandler(toUrl string) func(http.ResponseWriter, *http.Request) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, toUrl, http.StatusTemporaryRedirect) + }) +} + +func redirectToHandlerBranchToRef(toUrl string) func(http.ResponseWriter, *http.Request) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + u, _ := url.Parse(toUrl) + qs := u.Query() + qs.Set("ref", "refs/heads/"+r.PostFormValue("branch")) + tab := TabFromRequest(r) + if v, ok := r.Context().Value(ContextValues("redirectTab")).(string); ok { + if parsed, err := strconv.Atoi(v); err == nil { + tab = parsed + } + } + u.Path = TabPath(tab) + page := r.PostFormValue("page") + if v, ok := r.Context().Value(ContextValues("redirectPage")).(string); ok { + page = v + } + if fragment := PageFragmentFromIndex(page); fragment != "" { + u.Fragment = fragment + } + if edit := r.URL.Query().Get("edit"); edit != "" { + qs.Set("edit", edit) + } + u.RawQuery = qs.Encode() + http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect) + }) +} + +func redirectToHandlerTabPage(toUrl string) func(http.ResponseWriter, *http.Request) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + u, _ := url.Parse(toUrl) + qs := u.Query() + u.Path = TabPath(TabFromRequest(r)) + if fragment := PageFragmentFromIndex(r.URL.Query().Get("page")); fragment != "" { + u.Fragment = fragment + } + if edit := r.URL.Query().Get("edit"); edit != "" { + qs.Set("edit", edit) + } + u.RawQuery = qs.Encode() + http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect) + }) +} + +func RequiresAnAccount() mux.MatcherFunc { + return func(request *http.Request, match *mux.RouteMatch) bool { + var session *sessions.Session + sessioni := request.Context().Value(ContextValues("session")) + if sessioni == nil { + var err error + session, err = SessionStore.Get(request, SessionName) + if err != nil { + return false + } + } else { + var ok bool + session, ok = sessioni.(*sessions.Session) + if !ok { + return false + } + } + if v, ok := session.Values["version"].(string); !ok || v != version { + return false + } + githubUser, ok := session.Values["GithubUser"].(*User) + return ok && githubUser != nil + } +} + +func TaskMatcher(taskName string) mux.MatcherFunc { + return func(request *http.Request, match *mux.RouteMatch) bool { + return request.PostFormValue("task") == taskName + } +} + +func ModeMatcher(modeName string) mux.MatcherFunc { + return func(request *http.Request, match *mux.RouteMatch) bool { + return request.URL.Query().Get("mode") == modeName + } +} + +func HasError() mux.MatcherFunc { + return func(request *http.Request, match *mux.RouteMatch) bool { + return request.URL.Query().Has("error") + } +} + +func NoTask() mux.MatcherFunc { + return func(request *http.Request, match *mux.RouteMatch) bool { + return request.PostFormValue("task") == "" + } +} + +func CreatePEMFiles() { + notBefore := time.Now() + notAfter := notBefore.Add(365 * 24 * time.Hour) // Valid for 1 year + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + log.Fatalf("Failed to generate serial number: %v", err) + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Your Organization"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + log.Fatalf("Failed to generate private key: %v", err) + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + log.Fatalf("Failed to create certificate: %v", err) + } + + certFile, err := os.Create("cert.pem") + if err != nil { + log.Fatalf("Failed to create cert.pem file: %v", err) + } + defer certFile.Close() + if err := pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { + log.Fatalf("Failed to write data to cert.pem: %v", err) + } + + keyFile, err := os.Create("key.pem") + if err != nil { + log.Fatalf("Failed to create key.pem file: %v", err) + } + defer keyFile.Close() + privBytes, err := x509.MarshalECPrivateKey(priv) + if err != nil { + log.Fatalf("Failed to marshal private key: %v", err) + } + if err := pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes}); err != nil { + log.Fatalf("Failed to write data to key.pem: %v", err) + } +} diff --git a/router_helpers_test.go b/router_helpers_test.go new file mode 100644 index 0000000..8788301 --- /dev/null +++ b/router_helpers_test.go @@ -0,0 +1,67 @@ +package gobookmarks + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gorilla/sessions" + "golang.org/x/oauth2" +) + +func TestRunHandlerChain_UserErrorRedirect(t *testing.T) { + SessionName = "testsess" + SessionStore = sessions.NewCookieStore([]byte("secret")) + + req := httptest.NewRequest("GET", "/submit", nil) + req.Header.Set("Referer", "/form") + ctx := context.WithValue(req.Context(), ContextValues("coreData"), &CoreData{}) + req = req.WithContext(ctx) + + h := runHandlerChain(func(w http.ResponseWriter, r *http.Request) error { + return NewUserError("bad input", errors.New("invalid")) + }) + + w := httptest.NewRecorder() + h(w, req) + res := w.Result() + if res.StatusCode != http.StatusSeeOther { + t.Fatalf("expected redirect, got %d", res.StatusCode) + } + loc := res.Header.Get("Location") + if !strings.Contains(loc, "error=bad+input") { + t.Fatalf("redirect missing error param: %s", loc) + } +} + +func TestRunTemplate_BufferedError(t *testing.T) { + SessionName = "testsess" + SessionStore = sessions.NewCookieStore([]byte("secret")) + DBConnectionProvider = "" + + req := httptest.NewRequest("GET", "/", nil) + sess, _ := SessionStore.New(req, SessionName) + sess.Values["GithubUser"] = &User{Login: "user"} + sess.Values["Token"] = &oauth2.Token{} + ctx := context.WithValue(req.Context(), ContextValues("session"), sess) + ctx = context.WithValue(ctx, ContextValues("provider"), "sql") + ctx = context.WithValue(ctx, ContextValues("coreData"), &CoreData{UserRef: "user"}) + req = req.WithContext(ctx) + + w := httptest.NewRecorder() + runTemplate("mainPage.gohtml")(w, req) + + body := w.Body.String() + if !strings.Contains(body, "Database error") { + t.Fatalf("expected database error message, got %q", body) + } + if strings.Count(body, "") != 1 { + t.Fatalf("unexpected partial content: %q", body) + } + if strings.Contains(body, "tab-list") { + t.Fatalf("unexpected partial page content: %q", body) + } +} From a2124ac351a7d92f48893d79bef9d7f8c1f65197 Mon Sep 17 00:00:00 2001 From: Arran Ubels Date: Thu, 29 Jan 2026 22:00:37 +1100 Subject: [PATCH 2/2] refactor: Introduce core interfaces and application structure to improve modularity and embeddability. --- app/app.go | 28 +++++ app/config.go | 23 ++++ authHandlers.go | 24 ++-- autoRefreshPage.go | 6 +- bookmarkActionHandlers.go | 30 ++--- categoryAddHandlers.go | 20 ++-- categoryEditHandlers.go | 14 ++- category_edit_http_test.go | 9 +- cmd/gobookmarks/serve.go | 27 +++-- .../test_verification_template_command.go | 32 ++--- context.go | 18 +++ core.go | 30 ++--- core/coredata.go | 104 +++++++++++++++++ core/interfaces.go | 15 +++ data_embedded_func_override_test.go | 6 +- data_test.go | 24 ++-- devHandlers.go | 8 +- devHandlers_test.go | 3 +- editModeHandlers_test.go | 7 +- funcs.go | 74 ++++++------ moveHandlers.go | 18 +-- pageEditHandlers.go | 22 ++-- provider.go | 33 ++---- provider_access.go | 104 +++++++++++------ provider_git.go | 27 ++--- provider_github.go | 27 +++-- provider_gitlab.go | 29 +++-- provider_sql.go | 31 +++-- router.go | 110 +++++++++++------- router_helpers.go | 49 ++++---- router_helpers_test.go | 11 +- session_test.go | 6 +- signup_git_http_test.go | 7 +- tabEditHandlers.go | 21 ++-- 34 files changed, 641 insertions(+), 356 deletions(-) create mode 100644 app/app.go create mode 100644 app/config.go create mode 100644 context.go create mode 100644 core/coredata.go create mode 100644 core/interfaces.go diff --git a/app/app.go b/app/app.go new file mode 100644 index 0000000..d139268 --- /dev/null +++ b/app/app.go @@ -0,0 +1,28 @@ +package app + +import ( + "database/sql" + + "github.com/arran4/gobookmarks/core" + "github.com/gorilla/sessions" +) + +// App holds the application's long-lived dependencies. +type App struct { + DB *sql.DB + SessionStore sessions.Store + Repo core.Repo + UserProvider core.UserProvider + Config *Config +} + +// NewApp creates a new App instance. +func NewApp(db *sql.DB, store sessions.Store, repo core.Repo, userProvider core.UserProvider, cfg *Config) *App { + return &App{ + DB: db, + SessionStore: store, + Repo: repo, + UserProvider: userProvider, + Config: cfg, + } +} diff --git a/app/config.go b/app/config.go new file mode 100644 index 0000000..d344cc6 --- /dev/null +++ b/app/config.go @@ -0,0 +1,23 @@ +package app + +// Config holds the application configuration. +type Config struct { + SessionName string + BaseURL string + ExternalURL string + DevMode bool + GithubClientID string + GithubSecret string + GithubServer string + GitlabClientID string + GitlabSecret string + GitlabServer string + Title string + CssColumns bool + NoFooter bool + LocalGitPath string + CommitsPerPage int + FaviconCacheDir string + FaviconCacheSize int64 + FaviconMaxCacheCount int +} diff --git a/authHandlers.go b/authHandlers.go index 91c47ea..dcae4d9 100644 --- a/authHandlers.go +++ b/authHandlers.go @@ -4,23 +4,27 @@ import ( "context" "errors" "fmt" + "log" + "net/http" + + "github.com/arran4/gobookmarks/core" "github.com/gorilla/mux" "github.com/gorilla/sessions" "golang.org/x/oauth2" - "log" - "net/http" ) func UserLogoutAction(w http.ResponseWriter, r *http.Request) error { type Data struct { - *CoreData + *core.CoreData } data := Data{ - CoreData: r.Context().Value(ContextValues("coreData")).(*CoreData), + CoreData: r.Context().Value(core.ContextValues("coreData")).(*core.CoreData), } - session := r.Context().Value(ContextValues("session")).(*sessions.Session) + // Use the Core interface to get the session + cc := r.Context().Value(core.ContextValues("coreData")).(core.Core) + session := cc.GetSession() delete(session.Values, "GithubUser") delete(session.Values, "Token") delete(session.Values, "Provider") @@ -107,7 +111,7 @@ func LoginWithProvider(w http.ResponseWriter, r *http.Request) error { func Oauth2CallbackPage(w http.ResponseWriter, r *http.Request) error { type ErrorData struct { - *CoreData + *core.CoreData Error string } @@ -160,7 +164,7 @@ func Oauth2CallbackPage(w http.ResponseWriter, r *http.Request) error { return fmt.Errorf("user lookup error: %w", err) } - if err := ensureRepo(r.Context(), p, user.Login, token); err != nil { + if err := ensureRepo(r.Context(), p, user.GetLogin(), token); err != nil { // expire the session from the login step session.Options.MaxAge = -1 _ = session.Save(r, w) @@ -204,7 +208,7 @@ func GitLoginAction(w http.ResponseWriter, r *http.Request) error { return nil } session.Values["Provider"] = "git" - session.Values["GithubUser"] = &User{Login: user} + session.Values["GithubUser"] = &core.BasicUser{Login: user} session.Values["Token"] = nil session.Values["version"] = version if err := session.Save(r, w); err != nil { @@ -270,7 +274,7 @@ func SqlLoginAction(w http.ResponseWriter, r *http.Request) error { return nil } session.Values["Provider"] = "sql" - session.Values["GithubUser"] = &User{Login: user} + session.Values["GithubUser"] = &core.BasicUser{Login: user} session.Values["Token"] = nil session.Values["version"] = version if err := session.Save(r, w); err != nil { @@ -320,7 +324,7 @@ func UserAdderMiddleware(next http.Handler) http.Handler { log.Printf("session error: %v", err) } - ctx := context.WithValue(request.Context(), ContextValues("session"), session) + ctx := context.WithValue(request.Context(), core.ContextValues("session"), session) next.ServeHTTP(writer, request.WithContext(ctx)) }) } diff --git a/autoRefreshPage.go b/autoRefreshPage.go index 6379d10..760bb88 100644 --- a/autoRefreshPage.go +++ b/autoRefreshPage.go @@ -3,16 +3,18 @@ package gobookmarks import ( "fmt" "net/http" + + "github.com/arran4/gobookmarks/core" ) func TaskDoneAutoRefreshPage(w http.ResponseWriter, r *http.Request) error { type Data struct { - *CoreData + *core.CoreData Error string } data := Data{ - CoreData: r.Context().Value(ContextValues("coreData")).(*CoreData), + CoreData: r.Context().Value(core.ContextValues("coreData")).(*core.CoreData), Error: r.URL.Query().Get("error"), } diff --git a/bookmarkActionHandlers.go b/bookmarkActionHandlers.go index 911f561..9d0e4bb 100644 --- a/bookmarkActionHandlers.go +++ b/bookmarkActionHandlers.go @@ -3,16 +3,18 @@ package gobookmarks import ( "errors" "fmt" - "github.com/gorilla/sessions" - "golang.org/x/oauth2" "net/http" "strconv" + + "github.com/arran4/gobookmarks/core" + "github.com/gorilla/sessions" + "golang.org/x/oauth2" ) func BookmarksEditSaveAction(w http.ResponseWriter, r *http.Request) error { text := r.PostFormValue("text") - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := GetCore(r.Context()).GetSession() + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) branch := r.PostFormValue("branch") ref := r.PostFormValue("ref") @@ -57,8 +59,8 @@ func BookmarksEditSaveAction(w http.ResponseWriter, r *http.Request) error { func BookmarksEditCreateAction(w http.ResponseWriter, r *http.Request) error { text := r.PostFormValue("text") - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) branch := r.PostFormValue("branch") @@ -80,8 +82,8 @@ func CategoryEditSaveAction(w http.ResponseWriter, r *http.Request) error { if err != nil { return fmt.Errorf("invalid index: %w", err) } - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) branch := r.PostFormValue("branch") ref := r.PostFormValue("ref") @@ -126,8 +128,8 @@ func CategoryMoveBeforeAction(w http.ResponseWriter, r *http.Request) error { return fmt.Errorf("invalid to index: %w", err) } - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) login := "" @@ -178,8 +180,8 @@ func CategoryMoveEndAction(w http.ResponseWriter, r *http.Request) error { destCol = -1 } - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) login := "" @@ -237,8 +239,8 @@ func CategoryMoveNewColumnAction(w http.ResponseWriter, r *http.Request) error { } destCol, _ := strconv.Atoi(destColStr) - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) login := "" diff --git a/categoryAddHandlers.go b/categoryAddHandlers.go index 895d310..c256d0f 100644 --- a/categoryAddHandlers.go +++ b/categoryAddHandlers.go @@ -2,15 +2,19 @@ package gobookmarks import ( "fmt" - "github.com/gorilla/sessions" - "golang.org/x/oauth2" "net/http" "strconv" + + "github.com/arran4/gobookmarks/core" + "github.com/gorilla/sessions" + "golang.org/x/oauth2" ) func AddCategoryPage(w http.ResponseWriter, r *http.Request) error { - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + w.Header().Set("Content-Type", "text/html; charset=utf-8") + + session := GetCore(r.Context()).GetSession() + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) ref := r.URL.Query().Get("ref") @@ -27,14 +31,14 @@ func AddCategoryPage(w http.ResponseWriter, r *http.Request) error { col, _ := strconv.Atoi(r.URL.Query().Get("col")) data := struct { - *CoreData + *core.CoreData Error string Index int Text string Sha string Col int }{ - CoreData: r.Context().Value(ContextValues("coreData")).(*CoreData), + CoreData: r.Context().Value(core.ContextValues("coreData")).(*core.CoreData), Error: r.URL.Query().Get("error"), Index: -1, Text: "Category: ", @@ -57,8 +61,8 @@ func CategoryAddSaveAction(w http.ResponseWriter, r *http.Request) error { pageIdx, _ := strconv.Atoi(r.PostFormValue("page")) colIdx, _ := strconv.Atoi(r.PostFormValue("col")) - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) login := "" diff --git a/categoryEditHandlers.go b/categoryEditHandlers.go index a0dce2d..acdfa79 100644 --- a/categoryEditHandlers.go +++ b/categoryEditHandlers.go @@ -2,10 +2,12 @@ package gobookmarks import ( "fmt" - "github.com/gorilla/sessions" - "golang.org/x/oauth2" "net/http" "strconv" + + "github.com/arran4/gobookmarks/core" + "github.com/gorilla/sessions" + "golang.org/x/oauth2" ) func EditCategoryPage(w http.ResponseWriter, r *http.Request) error { @@ -14,8 +16,8 @@ func EditCategoryPage(w http.ResponseWriter, r *http.Request) error { if err != nil { return fmt.Errorf("invalid index: %w", err) } - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) ref := r.URL.Query().Get("ref") @@ -36,14 +38,14 @@ func EditCategoryPage(w http.ResponseWriter, r *http.Request) error { col, _ := strconv.Atoi(r.URL.Query().Get("col")) data := struct { - *CoreData + *core.CoreData Error string Index int Text string Sha string Col int }{ - CoreData: r.Context().Value(ContextValues("coreData")).(*CoreData), + CoreData: r.Context().Value(core.ContextValues("coreData")).(*core.CoreData), Error: r.URL.Query().Get("error"), Index: idx, Text: text, diff --git a/category_edit_http_test.go b/category_edit_http_test.go index 72893f1..643df3a 100644 --- a/category_edit_http_test.go +++ b/category_edit_http_test.go @@ -7,6 +7,7 @@ import ( "strings" "testing" + "github.com/arran4/gobookmarks/core" "github.com/gorilla/sessions" "golang.org/x/oauth2" ) @@ -26,11 +27,11 @@ func setupCategoryEditTest(t *testing.T) (GitProvider, string, *sessions.Session if err != nil { t.Fatalf("getSession: %v", err) } - sess.Values["GithubUser"] = &User{Login: user} + sess.Values["GithubUser"] = &core.BasicUser{Login: user} sess.Values["Token"] = &oauth2.Token{} - ctx := context.WithValue(sessReq.Context(), ContextValues("session"), sess) - ctx = context.WithValue(ctx, ContextValues("provider"), "git") - ctx = context.WithValue(ctx, ContextValues("coreData"), &CoreData{}) + ctx := context.WithValue(sessReq.Context(), core.ContextValues("session"), sess) + ctx = context.WithValue(ctx, core.ContextValues("provider"), "git") + ctx = context.WithValue(ctx, core.ContextValues("coreData"), &core.CoreData{Session: sess}) return p, user, sess, ctx } diff --git a/cmd/gobookmarks/serve.go b/cmd/gobookmarks/serve.go index e3853e3..a411efe 100644 --- a/cmd/gobookmarks/serve.go +++ b/cmd/gobookmarks/serve.go @@ -18,6 +18,7 @@ import ( "time" . "github.com/arran4/gobookmarks" + "github.com/arran4/gobookmarks/app" "github.com/gorilla/securecookie" "github.com/gorilla/sessions" ) @@ -251,16 +252,26 @@ func (c *ServeCommand) Execute(args []string) error { return errors.New("no providers available") } - // Create RouterConfig - routerCfg := &RouterConfig{ - SessionStore: SessionStore, // Globals should be initialized by now (lines 248-257) - SessionName: SessionName, - ExternalURL: cfg.ExternalURL, - BaseURL: "", // Root - DevMode: *cfg.DevMode, + // Create App + appCfg := &app.Config{ + SessionName: SessionName, + ExternalURL: cfg.ExternalURL, + BaseURL: "", // Root + DevMode: *cfg.DevMode, } - r := NewRouter(routerCfg) + // For standalone serve, we don't fix the Repo in App usually, providing nil allows fallback to session-based provider. + // However, if we wanted to enforce SQL we could. Current behavior is dynamic. + // We need to pass DB and Store. + + // OpenDB if configured + db, _ := OpenDB() // Errors logged inside or handled? OpenDB returns error. + // serve.go L43 setup DB config. + // OpenDB uses DBConnectionProvider/String globals. + + application := app.NewApp(db, SessionStore, nil, nil, appCfg) + + r := NewRouter(application) http.Handle("/", r) diff --git a/cmd/gobookmarks/test_verification_template_command.go b/cmd/gobookmarks/test_verification_template_command.go index 3510c78..745095e 100644 --- a/cmd/gobookmarks/test_verification_template_command.go +++ b/cmd/gobookmarks/test_verification_template_command.go @@ -11,6 +11,7 @@ import ( "os" . "github.com/arran4/gobookmarks" + "github.com/arran4/gobookmarks/core" ) type TemplateCommand struct { @@ -69,7 +70,7 @@ func (c *TemplateCommand) Execute(args []string) error { return c.HelpCmd.Execute(remaining[1:]) } - coreData := &CoreData{ + coreData := &core.CoreData{ Title: "Test Verification", UserRef: "testuser", EditMode: false, @@ -154,7 +155,7 @@ https://stackoverflow.com Stack Overflow // However, some funcs might need session or other context values. // For "useCssColumns" etc. - ctx := context.WithValue(req.Context(), ContextValues("coreData"), coreData) + ctx := context.WithValue(req.Context(), core.ContextValues("coreData"), coreData) req = req.WithContext(ctx) // Create funcs that override the default behavior to return our static data @@ -188,16 +189,16 @@ https://stackoverflow.com Stack Overflow } if indexName != "" { href := TabHref(i, "") // No ref in static mode - lastSha := "" // No SHA in static mode + lastSha := "" // No SHA in static mode if len(t.Pages) > 0 { lastSha = t.Pages[len(t.Pages)-1].Sha() } tabs = append(tabs, TabInfo{ - Index: i, - Name: t.Name, - IndexName: indexName, - Href: href, - EditHref: AppendQueryParams(href, "edit", "1"), + Index: i, + Name: t.Name, + IndexName: indexName, + Href: href, + EditHref: AppendQueryParams(href, "edit", "1"), LastPageSha: lastSha, }) } @@ -220,14 +221,14 @@ https://stackoverflow.com Stack Overflow } tabs = append(tabs, TabWithPages{ TabInfo: TabInfo{ - Index: i, - Name: t.Name, - IndexName: indexName, - Href: href, - EditHref: AppendQueryParams(href, "edit", "1"), + Index: i, + Name: t.Name, + IndexName: indexName, + Href: href, + EditHref: AppendQueryParams(href, "edit", "1"), LastPageSha: lastSha, }, - Pages: t.Pages, + Pages: t.Pages, }) } } @@ -252,12 +253,11 @@ https://stackoverflow.com Stack Overflow funcs["loggedIn"] = func() (bool, error) { return true, nil } funcs["showPages"] = func() bool { return true } - // Compile templates with our modified funcs tmpl := GetCompiledTemplates(funcs) type Data struct { - *CoreData + *core.CoreData Error string } data := Data{ diff --git a/context.go b/context.go new file mode 100644 index 0000000..850fbfb --- /dev/null +++ b/context.go @@ -0,0 +1,18 @@ +package gobookmarks + +import ( + "context" + + "github.com/arran4/gobookmarks/core" +) + +// GetCore retrieves the Core interface from the context. +// It returns nil if not found or if the value does not implement Core. +func GetCore(ctx context.Context) core.Core { + if v := ctx.Value(core.ContextValues("coreData")); v != nil { + if c, ok := v.(core.Core); ok { + return c + } + } + return nil +} diff --git a/core.go b/core.go index ac549cb..f8df739 100644 --- a/core.go +++ b/core.go @@ -4,23 +4,26 @@ import ( "bufio" "context" "encoding/gob" - "github.com/gorilla/sessions" - "golang.org/x/oauth2" "log" "net/http" "os" "strings" + + "github.com/gorilla/sessions" + "golang.org/x/oauth2" + + "github.com/arran4/gobookmarks/core" ) func init() { - gob.Register(&User{}) + gob.Register(&core.BasicUser{}) gob.Register(&oauth2.Token{}) } func CoreAdderMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - session := request.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := request.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) providerName, _ := session.Values["Provider"].(string) login := "" @@ -36,29 +39,20 @@ func CoreAdderMiddleware(next http.Handler) http.Handler { title = "dev: " + title } - ctx := context.WithValue(request.Context(), ContextValues("provider"), providerName) + ctx := context.WithValue(request.Context(), core.ContextValues("provider"), providerName) editMode := request.URL.Query().Get("edit") == "1" tab := TabFromRequest(request) - ctx = context.WithValue(ctx, ContextValues("coreData"), &CoreData{ + ctx = context.WithValue(ctx, core.ContextValues("coreData"), &core.CoreData{ UserRef: login, Title: title, EditMode: editMode, Tab: tab, - requestCache: &requestCache{data: make(map[string]*bookmarkCacheEntry)}, + RequestCache: &core.RequestCache{Data: make(map[string]*core.BookmarkCacheEntry)}, }) next.ServeHTTP(writer, request.WithContext(ctx)) }) } -type CoreData struct { - Title string - AutoRefresh bool - UserRef string - EditMode bool - Tab int - requestCache *requestCache -} - type Configuration struct { data map[string]string } @@ -101,5 +95,3 @@ func (c *Configuration) readConfiguration(filename string) { } } } - -type ContextValues string diff --git a/core/coredata.go b/core/coredata.go new file mode 100644 index 0000000..dd7b3b1 --- /dev/null +++ b/core/coredata.go @@ -0,0 +1,104 @@ +package core + +import ( + "context" + "net/http" + "sync" + "time" + + "github.com/gorilla/sessions" + "golang.org/x/oauth2" +) + +// CoreData represents the request-scoped context data. +type CoreData struct { + User User + Repo Repo + EditMode bool + Tab int + Title string + AutoRefresh bool + UserRef string // For backward compatibility if needed, or derived from User + Session *sessions.Session + // Add other request-scoped fields as needed + RequestCache *RequestCache +} + +func (cd *CoreData) GetSession() *sessions.Session { + return cd.Session +} + +type RequestCache struct { + sync.RWMutex + Data map[string]*BookmarkCacheEntry +} + +type BookmarkCacheEntry struct { + Bookmarks string + SHA string + Expiry time.Time +} + +// BasicUser represents the authenticated user. +type BasicUser struct { + Login string +} + +func (u *BasicUser) GetLogin() string { + return u.Login +} + +func (cd *CoreData) GetUser() User { + // If User is nil, we return nil (or typed nil which is fine as interface value usually checks nil) + // However, usually we return the interface. + if cd.User == nil { + return nil + } + return cd.User +} + +// UserProvider defines the interface for external user management +type UserProvider interface { + // CurrentUser returns the current user from the request context + CurrentUser(r *http.Request) (User, error) + IsLoggedIn(r *http.Request) bool +} + +// Repo defines the interface for data access. +type Repo interface { + GetBookmarks(ctx context.Context, user, ref string, token *oauth2.Token) (string, string, error) + UpdateBookmarks(ctx context.Context, user string, token *oauth2.Token, sourceRef, branch, text, expectSHA string) error + CreateBookmarks(ctx context.Context, user string, token *oauth2.Token, branch, text string) error + RepoExists(ctx context.Context, user string, token *oauth2.Token, name string) (bool, error) + CreateRepo(ctx context.Context, user string, token *oauth2.Token, name string) error + CreateUser(ctx context.Context, user, password string) error + CheckPassword(ctx context.Context, user, password string) (bool, error) + GetTags(ctx context.Context, user string, token *oauth2.Token) ([]*Tag, error) + GetBranches(ctx context.Context, user string, token *oauth2.Token) ([]*Branch, error) + GetCommits(ctx context.Context, user string, token *oauth2.Token, ref string, page, perPage int) ([]*Commit, error) + AdjacentCommits(ctx context.Context, user string, token *oauth2.Token, ref, sha string) (string, string, error) +} + +// Type definitions for params (moved from top level or redefined) +type Branch struct { + Name string +} + +type Commit struct { + SHA string + Message string + CommitterName string + CommitterEmail string + CommitterDate time.Time +} + +type Tag struct { + Name string +} + +// SessionManager abstract session operations if needed +type SessionManager interface { + Get(r any, name string) (*sessions.Session, error) +} + +type ContextValues string diff --git a/core/interfaces.go b/core/interfaces.go new file mode 100644 index 0000000..dfa2d0a --- /dev/null +++ b/core/interfaces.go @@ -0,0 +1,15 @@ +package core + +import "github.com/gorilla/sessions" + +// Core defines the interface for the core data accessible to handlers. +// This allows gobookmarks to be embedded in other applications (like goa4web) +// that validly implement this interface. +type Core interface { + GetSession() *sessions.Session + GetUser() User +} + +type User interface { + GetLogin() string +} diff --git a/data_embedded_func_override_test.go b/data_embedded_func_override_test.go index dcd1e62..ab75fca 100644 --- a/data_embedded_func_override_test.go +++ b/data_embedded_func_override_test.go @@ -7,6 +7,8 @@ import ( "bytes" "strings" "testing" + + "github.com/arran4/gobookmarks/core" ) func TestGetCompiledTemplates_FuncOverride(t *testing.T) { @@ -27,10 +29,10 @@ func TestGetCompiledTemplates_FuncOverride(t *testing.T) { // "loginPage.gohtml" typically uses {{ version }} in the footer var buf bytes.Buffer data := struct { - *CoreData + *core.CoreData Error string }{ - CoreData: &CoreData{Title: "Test", UserRef: "user"}, + CoreData: &core.CoreData{Title: "Test", UserRef: "user"}, } if err := tmpl.ExecuteTemplate(&buf, "loginPage.gohtml", data); err != nil { diff --git a/data_test.go b/data_test.go index 54d4c56..cf67634 100644 --- a/data_test.go +++ b/data_test.go @@ -8,6 +8,8 @@ import ( "strings" "testing" "time" + + "github.com/arran4/gobookmarks/core" ) func TestCompileGoHTML(t *testing.T) { @@ -141,14 +143,14 @@ func testFuncMap() template.FuncMap { "bookmarksExist": func() (bool, error) { return true, nil }, "bookmarksSHA": func() (string, error) { return "sha", nil }, "branchOrEditBranch": func() (string, error) { return "main", nil }, - "tags": func() ([]*Tag, error) { - return []*Tag{{Name: "v1"}}, nil + "tags": func() ([]*core.Tag, error) { + return []*core.Tag{{Name: "v1"}}, nil }, - "branches": func() ([]*Branch, error) { - return []*Branch{{Name: "main"}}, nil + "branches": func() ([]*core.Branch, error) { + return []*core.Branch{{Name: "main"}}, nil }, - "commits": func() ([]*Commit, error) { - return []*Commit{{ + "commits": func() ([]*core.Commit, error) { + return []*core.Commit{{ SHA: "abc", Message: "msg", CommitterName: "dev", @@ -176,14 +178,14 @@ func TestExecuteTemplates(t *testing.T) { t.Fatalf("template parse error: %v", err) } baseData := struct { - *CoreData + *core.CoreData Error string }{ - CoreData: &CoreData{Title: "Test", UserRef: "user"}, + CoreData: &core.CoreData{Title: "Test", UserRef: "user"}, } catData := struct { - *CoreData + *core.CoreData Error string Index int Text string @@ -198,7 +200,7 @@ func TestExecuteTemplates(t *testing.T) { } pageData := struct { - *CoreData + *core.CoreData Error string Name string Text string @@ -225,7 +227,7 @@ func TestExecuteTemplates(t *testing.T) { {"historyCommits", "historyCommits.gohtml", baseData}, {"taskDone", "taskDoneAutoRefreshPage.gohtml", baseData}, {"error", "error.gohtml", struct { - *CoreData + *core.CoreData Error string }{baseData.CoreData, "boom"}}, } diff --git a/devHandlers.go b/devHandlers.go index d308f70..8850a66 100644 --- a/devHandlers.go +++ b/devHandlers.go @@ -2,13 +2,15 @@ package gobookmarks import ( "fmt" - "github.com/gorilla/sessions" "net/http" + + "github.com/arran4/gobookmarks/core" + "github.com/gorilla/sessions" ) // EnableCssColumnsAction stores a session flag to use CSS column layout. func EnableCssColumnsAction(w http.ResponseWriter, r *http.Request) error { - session := r.Context().Value(ContextValues("session")).(*sessions.Session) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) session.Values["useCssColumns"] = true if err := session.Save(r, w); err != nil { return fmt.Errorf("session save: %w", err) @@ -18,7 +20,7 @@ func EnableCssColumnsAction(w http.ResponseWriter, r *http.Request) error { // DisableCssColumnsAction stores a session flag to use table layout. func DisableCssColumnsAction(w http.ResponseWriter, r *http.Request) error { - session := r.Context().Value(ContextValues("session")).(*sessions.Session) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) session.Values["useCssColumns"] = false if err := session.Save(r, w); err != nil { return fmt.Errorf("session save: %w", err) diff --git a/devHandlers_test.go b/devHandlers_test.go index 3bb1ff4..e04b822 100644 --- a/devHandlers_test.go +++ b/devHandlers_test.go @@ -5,6 +5,7 @@ import ( "net/http/httptest" "testing" + "github.com/arran4/gobookmarks/core" "github.com/gorilla/sessions" ) @@ -18,7 +19,7 @@ func TestCssColumnToggle(t *testing.T) { if err != nil { t.Fatalf("getSession: %v", err) } - ctx := context.WithValue(req.Context(), ContextValues("session"), session) + ctx := context.WithValue(req.Context(), core.ContextValues("session"), session) req = req.WithContext(ctx) w = httptest.NewRecorder() diff --git a/editModeHandlers_test.go b/editModeHandlers_test.go index ff2efe3..4962e1a 100644 --- a/editModeHandlers_test.go +++ b/editModeHandlers_test.go @@ -6,6 +6,7 @@ import ( "net/http/httptest" "testing" + "github.com/arran4/gobookmarks/core" "github.com/gorilla/sessions" ) @@ -19,7 +20,7 @@ func TestEditModeToggle(t *testing.T) { if err != nil { t.Fatalf("getSession: %v", err) } - ctx := context.WithValue(req.Context(), ContextValues("session"), session) + ctx := context.WithValue(req.Context(), core.ContextValues("session"), session) req = req.WithContext(ctx) // enable edit mode @@ -31,9 +32,9 @@ func TestEditModeToggle(t *testing.T) { t.Fatalf("edit mode query not set") } - var cd *CoreData + var cd *core.CoreData handler := CoreAdderMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - cd = r.Context().Value(ContextValues("coreData")).(*CoreData) + cd = r.Context().Value(core.ContextValues("coreData")).(*core.CoreData) })) w = httptest.NewRecorder() handler.ServeHTTP(w, req) diff --git a/funcs.go b/funcs.go index cf9ecc1..b96e3d2 100644 --- a/funcs.go +++ b/funcs.go @@ -3,13 +3,15 @@ package gobookmarks import ( "errors" "fmt" - "github.com/gorilla/sessions" - "golang.org/x/oauth2" "html/template" "net/http" "strconv" "strings" "time" + + "github.com/arran4/gobookmarks/core" + "github.com/gorilla/sessions" + "golang.org/x/oauth2" ) // TabInfo is used by templates to display tab navigation with indexes. @@ -129,7 +131,7 @@ func NewFuncs(r *http.Request) template.FuncMap { return i }, "useCssColumns": func() bool { - sessioni := r.Context().Value(ContextValues("session")) + sessioni := r.Context().Value(core.ContextValues("session")) if session, ok := sessioni.(*sessions.Session); ok && session != nil { if v, ok := session.Values["useCssColumns"].(bool); ok { return v @@ -150,20 +152,20 @@ func NewFuncs(r *http.Request) template.FuncMap { if strings.HasPrefix(r.URL.Path, "/login") || r.URL.Path == "/status" { return false } - sessioni := r.Context().Value(ContextValues("session")) + sessioni := r.Context().Value(core.ContextValues("session")) session, ok := sessioni.(*sessions.Session) if !ok || session == nil { return false } - githubUser, ok := session.Values["GithubUser"].(*User) + githubUser, ok := session.Values["GithubUser"].(*core.BasicUser) if !ok || githubUser == nil { return false } return true }, "loggedIn": func() (bool, error) { - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, ok := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, ok := session.Values["GithubUser"].(*core.BasicUser) return ok && githubUser != nil, nil }, "bookmarks": func() (string, error) { @@ -179,8 +181,8 @@ func NewFuncs(r *http.Request) template.FuncMap { return BookmarksExist(r) }, "bookmarksSHA": func() (string, error) { - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) login := "" if githubUser != nil { @@ -210,8 +212,8 @@ func NewFuncs(r *http.Request) template.FuncMap { return "main", nil }, "bookmarkPages": func() ([]*BookmarkPage, error) { - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) login := "" @@ -239,8 +241,8 @@ func NewFuncs(r *http.Request) template.FuncMap { return tabs[idx].Pages, nil }, "bookmarkTabs": func() ([]TabInfo, error) { - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) login := "" @@ -279,8 +281,8 @@ func NewFuncs(r *http.Request) template.FuncMap { return tabs, nil }, "bookmarkTabsWithPages": func() ([]TabWithPages, error) { - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) login := "" @@ -322,8 +324,8 @@ func NewFuncs(r *http.Request) template.FuncMap { return tabs, nil }, "tabName": func() string { - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) login := "" @@ -354,8 +356,8 @@ func NewFuncs(r *http.Request) template.FuncMap { return name }, "bookmarkColumns": func() ([]*BookmarkColumn, error) { - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) login := "" @@ -385,9 +387,9 @@ func NewFuncs(r *http.Request) template.FuncMap { } return columns, nil }, - "tags": func() ([]*Tag, error) { - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + "tags": func() ([]*core.Tag, error) { + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) login := "" @@ -401,9 +403,9 @@ func NewFuncs(r *http.Request) template.FuncMap { } return tags, nil }, - "branches": func() ([]*Branch, error) { - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + "branches": func() ([]*core.Branch, error) { + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) login := "" @@ -416,9 +418,9 @@ func NewFuncs(r *http.Request) template.FuncMap { } return branches, nil }, - "commits": func() ([]*Commit, error) { - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + "commits": func() ([]*core.Commit, error) { + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) login := "" @@ -437,8 +439,8 @@ func NewFuncs(r *http.Request) template.FuncMap { return commits, nil }, "prevCommit": func() string { - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) login := "" @@ -457,8 +459,8 @@ func NewFuncs(r *http.Request) template.FuncMap { return prev }, "nextCommit": func() string { - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) login := "" @@ -489,8 +491,8 @@ func NewFuncs(r *http.Request) template.FuncMap { } func Bookmarks(r *http.Request) (string, error) { - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) ref := r.URL.Query().Get("ref") @@ -510,8 +512,8 @@ func Bookmarks(r *http.Request) (string, error) { } func BookmarksExist(r *http.Request) (bool, error) { - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) ref := r.URL.Query().Get("ref") diff --git a/moveHandlers.go b/moveHandlers.go index d36a9b2..edc6882 100644 --- a/moveHandlers.go +++ b/moveHandlers.go @@ -2,17 +2,19 @@ package gobookmarks import ( "fmt" - "github.com/gorilla/sessions" - "golang.org/x/oauth2" "net/http" "strconv" + + "github.com/arran4/gobookmarks/core" + "github.com/gorilla/sessions" + "golang.org/x/oauth2" ) func MoveTabAction(w http.ResponseWriter, r *http.Request) error { from, _ := strconv.Atoi(r.URL.Query().Get("from")) to, _ := strconv.Atoi(r.URL.Query().Get("to")) - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) login := "" if githubUser != nil { @@ -38,8 +40,8 @@ func MovePageAction(w http.ResponseWriter, r *http.Request) error { from, _ := strconv.Atoi(r.URL.Query().Get("from")) to, _ := strconv.Atoi(r.URL.Query().Get("to")) tabIdx := TabFromRequest(r) - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) login := "" if githubUser != nil { @@ -69,8 +71,8 @@ func MoveEntryAction(w http.ResponseWriter, r *http.Request) error { catIdx, _ := strconv.Atoi(r.URL.Query().Get("category")) tabIdx := TabFromRequest(r) pageIdx, _ := strconv.Atoi(r.URL.Query().Get("page")) - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) login := "" if githubUser != nil { diff --git a/pageEditHandlers.go b/pageEditHandlers.go index fe62782..af3fa42 100644 --- a/pageEditHandlers.go +++ b/pageEditHandlers.go @@ -3,15 +3,17 @@ package gobookmarks import ( "context" "fmt" - "github.com/gorilla/sessions" - "golang.org/x/oauth2" "net/http" "strconv" + + "github.com/arran4/gobookmarks/core" + "github.com/gorilla/sessions" + "golang.org/x/oauth2" ) func EditPagePage(w http.ResponseWriter, r *http.Request) error { - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) ref := r.URL.Query().Get("ref") tabIdx := TabFromRequest(r) @@ -28,13 +30,13 @@ func EditPagePage(w http.ResponseWriter, r *http.Request) error { } data := struct { - *CoreData + *core.CoreData Error string Name string Text string Sha string }{ - CoreData: r.Context().Value(ContextValues("coreData")).(*CoreData), + CoreData: r.Context().Value(core.ContextValues("coreData")).(*core.CoreData), Error: r.URL.Query().Get("error"), Name: r.URL.Query().Get("name"), Text: "", @@ -65,8 +67,8 @@ func PageEditSaveAction(w http.ResponseWriter, r *http.Request) error { tabIdx := TabFromRequest(r) pageIdx, pageErr := strconv.Atoi(r.PostFormValue("page")) - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := r.Context().Value(core.ContextValues("session")).(*sessions.Session) + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) login := "" @@ -96,8 +98,8 @@ func PageEditSaveAction(w http.ResponseWriter, r *http.Request) error { } else { newIndex := len(list[tabIdx].Pages) list[tabIdx].AddPage(p) - ctx := context.WithValue(r.Context(), ContextValues("redirectTab"), strconv.Itoa(tabIdx)) - ctx = context.WithValue(ctx, ContextValues("redirectPage"), strconv.Itoa(newIndex)) + ctx := context.WithValue(r.Context(), core.ContextValues("redirectTab"), strconv.Itoa(tabIdx)) + ctx = context.WithValue(ctx, core.ContextValues("redirectPage"), strconv.Itoa(newIndex)) *r = *r.WithContext(ctx) } diff --git a/provider.go b/provider.go index db0d464..d6293f0 100644 --- a/provider.go +++ b/provider.go @@ -2,38 +2,23 @@ package gobookmarks import ( "context" - "golang.org/x/oauth2" "sort" - "time" -) -type User struct { - Login string -} - -type Branch struct { - Name string -} + "github.com/arran4/gobookmarks/core" -type Tag struct { - Name string -} + "golang.org/x/oauth2" +) -type Commit struct { - SHA string - Message string - CommitterName string - CommitterEmail string - CommitterDate time.Time -} +// Types moved to core package, used directly now. type Provider interface { Name() string Config(clientID, clientSecret, redirectURL string) *oauth2.Config - CurrentUser(ctx context.Context, token *oauth2.Token) (*User, error) - GetTags(ctx context.Context, user string, token *oauth2.Token) ([]*Tag, error) - GetBranches(ctx context.Context, user string, token *oauth2.Token) ([]*Branch, error) - GetCommits(ctx context.Context, user string, token *oauth2.Token, ref string, page, perPage int) ([]*Commit, error) + // CurrentUser returns the currently authenticated user + CurrentUser(ctx context.Context, token *oauth2.Token) (core.User, error) + GetTags(ctx context.Context, user string, token *oauth2.Token) ([]*core.Tag, error) + GetBranches(ctx context.Context, user string, token *oauth2.Token) ([]*core.Branch, error) + GetCommits(ctx context.Context, user string, token *oauth2.Token, ref string, page, perPage int) ([]*core.Commit, error) GetBookmarks(ctx context.Context, user, ref string, token *oauth2.Token) (string, string, error) UpdateBookmarks(ctx context.Context, user string, token *oauth2.Token, sourceRef, branch, text, expectSHA string) error CreateBookmarks(ctx context.Context, user string, token *oauth2.Token, branch, text string) error diff --git a/provider_access.go b/provider_access.go index 685deb6..b20319d 100644 --- a/provider_access.go +++ b/provider_access.go @@ -8,34 +8,25 @@ import ( "time" "golang.org/x/oauth2" -) -type bookmarkCacheEntry struct { - bookmarks string - sha string - expiry time.Time -} + "github.com/arran4/gobookmarks/core" +) var bookmarksCache = struct { sync.RWMutex - data map[string]*bookmarkCacheEntry -}{data: make(map[string]*bookmarkCacheEntry)} + Data map[string]*core.BookmarkCacheEntry +}{Data: make(map[string]*core.BookmarkCacheEntry)} func cacheKey(user, ref string) string { return user + "|" + ref } -type requestCache struct { - sync.RWMutex - data map[string]*bookmarkCacheEntry -} - func invalidateRequestCache(ctx context.Context, user string) { - if cd, ok := ctx.Value(ContextValues("coreData")).(*CoreData); ok && cd.requestCache != nil { - cd.requestCache.Lock() - defer cd.requestCache.Unlock() + if cd, ok := ctx.Value(core.ContextValues("coreData")).(*core.CoreData); ok && cd.RequestCache != nil { + cd.RequestCache.Lock() + defer cd.RequestCache.Unlock() prefix := user + "|" - for k := range cd.requestCache.data { + for k := range cd.RequestCache.Data { if strings.HasPrefix(k, prefix) { - delete(cd.requestCache.data, k) + delete(cd.RequestCache.Data, k) } } } @@ -44,33 +35,33 @@ func invalidateRequestCache(ctx context.Context, user string) { func getCachedBookmarks(user, ref string) (string, string, bool) { key := cacheKey(user, ref) bookmarksCache.RLock() - entry, ok := bookmarksCache.data[key] + entry, ok := bookmarksCache.Data[key] bookmarksCache.RUnlock() - if !ok || time.Now().After(entry.expiry) { + if !ok || time.Now().After(entry.Expiry) { return "", "", false } - return entry.bookmarks, entry.sha, true + return entry.Bookmarks, entry.SHA, true } func setCachedBookmarks(user, ref, bookmarks, sha string) { key := cacheKey(user, ref) bookmarksCache.Lock() - bookmarksCache.data[key] = &bookmarkCacheEntry{bookmarks: bookmarks, sha: sha, expiry: time.Now().Add(time.Minute)} + bookmarksCache.Data[key] = &core.BookmarkCacheEntry{Bookmarks: bookmarks, SHA: sha, Expiry: time.Now().Add(time.Minute)} bookmarksCache.Unlock() } func invalidateBookmarkCache(user string) { bookmarksCache.Lock() - for k := range bookmarksCache.data { + for k := range bookmarksCache.Data { if strings.HasPrefix(k, user+"|") { - delete(bookmarksCache.data, k) + delete(bookmarksCache.Data, k) } } bookmarksCache.Unlock() } func providerFromContext(ctx context.Context) Provider { - if name, ok := ctx.Value(ContextValues("provider")).(string); ok { + if name, ok := ctx.Value(core.ContextValues("provider")).(string); ok { if p := GetProvider(name); p != nil { return p } @@ -110,7 +101,10 @@ func providerCreds(name string) *ProviderCreds { } } -func GetTags(ctx context.Context, user string, token *oauth2.Token) ([]*Tag, error) { +func GetTags(ctx context.Context, user string, token *oauth2.Token) ([]*core.Tag, error) { + if cd, ok := ctx.Value(core.ContextValues("coreData")).(*core.CoreData); ok && cd.Repo != nil { + return cd.Repo.GetTags(ctx, user, token) + } p := providerFromContext(ctx) if p == nil { return nil, ErrNoProvider @@ -122,7 +116,10 @@ func GetTags(ctx context.Context, user string, token *oauth2.Token) ([]*Tag, err return tags, err } -func GetBranches(ctx context.Context, user string, token *oauth2.Token) ([]*Branch, error) { +func GetBranches(ctx context.Context, user string, token *oauth2.Token) ([]*core.Branch, error) { + if cd, ok := ctx.Value(core.ContextValues("coreData")).(*core.CoreData); ok && cd.Repo != nil { + return cd.Repo.GetBranches(ctx, user, token) + } p := providerFromContext(ctx) if p == nil { return nil, ErrNoProvider @@ -134,7 +131,10 @@ func GetBranches(ctx context.Context, user string, token *oauth2.Token) ([]*Bran return bs, err } -func GetCommits(ctx context.Context, user string, token *oauth2.Token, ref string, page, perPage int) ([]*Commit, error) { +func GetCommits(ctx context.Context, user string, token *oauth2.Token, ref string, page, perPage int) ([]*core.Commit, error) { + if cd, ok := ctx.Value(core.ContextValues("coreData")).(*core.CoreData); ok && cd.Repo != nil { + return cd.Repo.GetCommits(ctx, user, token, ref, page, perPage) + } p := providerFromContext(ctx) if p == nil { return nil, ErrNoProvider @@ -147,6 +147,9 @@ func GetCommits(ctx context.Context, user string, token *oauth2.Token, ref strin } func GetAdjacentCommits(ctx context.Context, user string, token *oauth2.Token, ref, sha string) (string, string, error) { + if cd, ok := ctx.Value(core.ContextValues("coreData")).(*core.CoreData); ok && cd.Repo != nil { + return cd.Repo.AdjacentCommits(ctx, user, token, ref, sha) + } p := providerFromContext(ctx) if p == nil { return "", "", ErrNoProvider @@ -163,18 +166,27 @@ func GetAdjacentCommits(ctx context.Context, user string, token *oauth2.Token, r func GetBookmarks(ctx context.Context, user, ref string, token *oauth2.Token) (string, string, error) { key := cacheKey(user, ref) - if cd, ok := ctx.Value(ContextValues("coreData")).(*CoreData); ok && cd.requestCache != nil { - cd.requestCache.RLock() - if entry, ok := cd.requestCache.data[key]; ok { - cd.requestCache.RUnlock() - return entry.bookmarks, entry.sha, nil + if cd, ok := ctx.Value(core.ContextValues("coreData")).(*core.CoreData); ok && cd.RequestCache != nil { + cd.RequestCache.RLock() + if entry, ok := cd.RequestCache.Data[key]; ok { + cd.RequestCache.RUnlock() + return entry.Bookmarks, entry.SHA, nil } - cd.requestCache.RUnlock() + cd.RequestCache.RUnlock() } if b, sha, ok := getCachedBookmarks(user, ref); ok { return b, sha, nil } + if cd, ok := ctx.Value(core.ContextValues("coreData")).(*core.CoreData); ok && cd.Repo != nil { + b, sha, err := cd.Repo.GetBookmarks(ctx, user, ref, token) + if err == nil && cd.RequestCache != nil { + cd.RequestCache.Lock() + cd.RequestCache.Data[key] = &core.BookmarkCacheEntry{Bookmarks: b, SHA: sha} + cd.RequestCache.Unlock() + } + return b, sha, err + } p := providerFromContext(ctx) if p == nil { return "", "", ErrNoProvider @@ -184,16 +196,24 @@ func GetBookmarks(ctx context.Context, user, ref string, token *oauth2.Token) (s return "", "", ErrSignedOut } if err == nil { - if cd, ok := ctx.Value(ContextValues("coreData")).(*CoreData); ok && cd.requestCache != nil { - cd.requestCache.Lock() - cd.requestCache.data[key] = &bookmarkCacheEntry{bookmarks: b, sha: sha} - cd.requestCache.Unlock() + if cd, ok := ctx.Value(core.ContextValues("coreData")).(*core.CoreData); ok && cd.RequestCache != nil { + cd.RequestCache.Lock() + cd.RequestCache.Data[key] = &core.BookmarkCacheEntry{Bookmarks: b, SHA: sha} + cd.RequestCache.Unlock() } } return b, sha, err } func UpdateBookmarks(ctx context.Context, user string, token *oauth2.Token, sourceRef, branch, text, expectSHA string) error { + if cd, ok := ctx.Value(core.ContextValues("coreData")).(*core.CoreData); ok && cd.Repo != nil { + err := cd.Repo.UpdateBookmarks(ctx, user, token, sourceRef, branch, text, expectSHA) + if err == nil { + invalidateBookmarkCache(user) + invalidateRequestCache(ctx, user) + } + return err + } p := providerFromContext(ctx) if p == nil { return ErrNoProvider @@ -209,6 +229,14 @@ func UpdateBookmarks(ctx context.Context, user string, token *oauth2.Token, sour } func CreateBookmarks(ctx context.Context, user string, token *oauth2.Token, branch, text string) error { + if cd, ok := ctx.Value(core.ContextValues("coreData")).(*core.CoreData); ok && cd.Repo != nil { + err := cd.Repo.CreateBookmarks(ctx, user, token, branch, text) + if err == nil { + invalidateBookmarkCache(user) + invalidateRequestCache(ctx, user) + } + return err + } p := providerFromContext(ctx) if p == nil { return ErrNoProvider diff --git a/provider_git.go b/provider_git.go index 904679b..91f9758 100644 --- a/provider_git.go +++ b/provider_git.go @@ -13,6 +13,7 @@ import ( "golang.org/x/crypto/bcrypt" + "github.com/arran4/gobookmarks/core" git "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/plumbing" "github.com/go-git/go-git/v5/plumbing/object" @@ -28,8 +29,8 @@ func init() { RegisterProvider(GitProvider{}) } func (GitProvider) Name() string { return "git" } func (GitProvider) DefaultServer() string { return "" } func (GitProvider) Config(clientID, clientSecret, redirectURL string) *oauth2.Config { return nil } -func (GitProvider) CurrentUser(ctx context.Context, token *oauth2.Token) (*User, error) { - return &User{Login: "local"}, nil +func (p GitProvider) CurrentUser(ctx context.Context, token *oauth2.Token) (core.User, error) { + return &core.BasicUser{Login: "local"}, nil } func userDir(user string) string { @@ -48,7 +49,7 @@ func openRepo(user string) (*git.Repository, error) { return r, nil } -func (GitProvider) GetTags(ctx context.Context, user string, token *oauth2.Token) ([]*Tag, error) { +func (GitProvider) GetTags(ctx context.Context, user string, token *oauth2.Token) ([]*core.Tag, error) { r, err := openRepo(user) if err != nil { if errors.Is(err, ErrRepoNotFound) { @@ -60,19 +61,19 @@ func (GitProvider) GetTags(ctx context.Context, user string, token *oauth2.Token if err != nil { return nil, err } - var tags []*Tag + var tags []*core.Tag err = iter.ForEach(func(ref *plumbing.Reference) error { - tags = append(tags, &Tag{Name: ref.Name().Short()}) + tags = append(tags, &core.Tag{Name: ref.Name().Short()}) return nil }) return tags, err } -func (GitProvider) GetBranches(ctx context.Context, user string, token *oauth2.Token) ([]*Branch, error) { +func (GitProvider) GetBranches(ctx context.Context, user string, token *oauth2.Token) ([]*core.Branch, error) { r, err := openRepo(user) if err != nil { if errors.Is(err, ErrRepoNotFound) { - return []*Branch{{Name: "main"}}, nil + return []*core.Branch{{Name: "main"}}, nil } return nil, err } @@ -80,18 +81,18 @@ func (GitProvider) GetBranches(ctx context.Context, user string, token *oauth2.T if err != nil { return nil, err } - var branches []*Branch + var branches []*core.Branch err = iter.ForEach(func(ref *plumbing.Reference) error { - branches = append(branches, &Branch{Name: ref.Name().Short()}) + branches = append(branches, &core.Branch{Name: ref.Name().Short()}) return nil }) if len(branches) == 0 { - branches = append(branches, &Branch{Name: "main"}) + branches = append(branches, &core.Branch{Name: "main"}) } return branches, err } -func (GitProvider) GetCommits(ctx context.Context, user string, token *oauth2.Token, ref string, page, perPage int) ([]*Commit, error) { +func (GitProvider) GetCommits(ctx context.Context, user string, token *oauth2.Token, ref string, page, perPage int) ([]*core.Commit, error) { r, err := openRepo(user) if err != nil { if errors.Is(err, ErrRepoNotFound) { @@ -112,7 +113,7 @@ func (GitProvider) GetCommits(ctx context.Context, user string, token *oauth2.To } start := (page - 1) * perPage i := 0 - var commits []*Commit + var commits []*core.Commit err = iter.ForEach(func(c *object.Commit) error { if i < start { i++ @@ -121,7 +122,7 @@ func (GitProvider) GetCommits(ctx context.Context, user string, token *oauth2.To if len(commits) >= perPage { return storer.ErrStop } - commits = append(commits, &Commit{ + commits = append(commits, &core.Commit{ SHA: c.Hash.String(), Message: c.Message, CommitterName: c.Committer.Name, diff --git a/provider_github.go b/provider_github.go index 00bef18..bb1d5c0 100644 --- a/provider_github.go +++ b/provider_github.go @@ -11,6 +11,7 @@ import ( "net/http" "strings" + "github.com/arran4/gobookmarks/core" "github.com/google/go-github/v55/github" "golang.org/x/oauth2" ) @@ -20,6 +21,7 @@ type GitHubProvider struct{} func init() { gob.Register(&github.User{}) + gob.Register(&core.BasicUser{}) RegisterProvider(GitHubProvider{}) } @@ -57,56 +59,59 @@ func (GitHubProvider) client(ctx context.Context, token *oauth2.Token) *github.C return c } -func (p GitHubProvider) CurrentUser(ctx context.Context, token *oauth2.Token) (*User, error) { +func (p GitHubProvider) CurrentUser(ctx context.Context, token *oauth2.Token) (core.User, error) { + if token == nil || !token.Valid() { + return nil, nil + } u, _, err := p.client(ctx, token).Users.Get(ctx, "") if err != nil { log.Printf("github CurrentUser: %v", err) return nil, err } - user := &User{} + user := &core.BasicUser{} if u.Login != nil { user.Login = *u.Login } return user, nil } -func (p GitHubProvider) GetTags(ctx context.Context, user string, token *oauth2.Token) ([]*Tag, error) { +func (p GitHubProvider) GetTags(ctx context.Context, user string, token *oauth2.Token) ([]*core.Tag, error) { tags, _, err := p.client(ctx, token).Repositories.ListTags(ctx, user, RepoName, &github.ListOptions{}) if err != nil { log.Printf("github GetTags: %v", err) return nil, fmt.Errorf("ListTags: %w", err) } - res := make([]*Tag, 0, len(tags)) + res := make([]*core.Tag, 0, len(tags)) for _, t := range tags { - res = append(res, &Tag{Name: t.GetName()}) + res = append(res, &core.Tag{Name: t.GetName()}) } return res, nil } -func (p GitHubProvider) GetBranches(ctx context.Context, user string, token *oauth2.Token) ([]*Branch, error) { +func (p GitHubProvider) GetBranches(ctx context.Context, user string, token *oauth2.Token) ([]*core.Branch, error) { bs, _, err := p.client(ctx, token).Repositories.ListBranches(ctx, user, RepoName, &github.BranchListOptions{}) if err != nil { log.Printf("github GetBranches: %v", err) return nil, fmt.Errorf("ListBranches: %w", err) } - res := make([]*Branch, 0, len(bs)) + res := make([]*core.Branch, 0, len(bs)) for _, b := range bs { - res = append(res, &Branch{Name: b.GetName()}) + res = append(res, &core.Branch{Name: b.GetName()}) } return res, nil } -func (p GitHubProvider) GetCommits(ctx context.Context, user string, token *oauth2.Token, ref string, page, perPage int) ([]*Commit, error) { +func (p GitHubProvider) GetCommits(ctx context.Context, user string, token *oauth2.Token, ref string, page, perPage int) ([]*core.Commit, error) { opts := &github.CommitsListOptions{SHA: ref, ListOptions: github.ListOptions{Page: page, PerPage: perPage}} cs, _, err := p.client(ctx, token).Repositories.ListCommits(ctx, user, RepoName, opts) if err != nil { log.Printf("github GetCommits: %v", err) return nil, fmt.Errorf("ListCommits: %w", err) } - res := make([]*Commit, 0, len(cs)) + res := make([]*core.Commit, 0, len(cs)) for _, c := range cs { cm := c.GetCommit() - com := &Commit{SHA: c.GetSHA()} + com := &core.Commit{SHA: c.GetSHA()} if cm != nil { com.Message = cm.GetMessage() comm := cm.Committer diff --git a/provider_gitlab.go b/provider_gitlab.go index 346a053..e30cd0f 100644 --- a/provider_gitlab.go +++ b/provider_gitlab.go @@ -12,6 +12,7 @@ import ( "net/http" "strings" + "github.com/arran4/gobookmarks/core" gitlab "github.com/xanzy/go-gitlab" "golang.org/x/oauth2" ) @@ -29,6 +30,7 @@ func gitlabUnauthorized(err error) bool { func init() { gob.Register(&gitlab.User{}) + gob.Register(&core.BasicUser{}) RegisterProvider(GitLabProvider{}) } @@ -61,8 +63,11 @@ func (GitLabProvider) client(token *oauth2.Token) (*gitlab.Client, error) { return gitlab.NewOAuthClient(token.AccessToken, gitlab.WithBaseURL(server)) } -func (GitLabProvider) CurrentUser(ctx context.Context, token *oauth2.Token) (*User, error) { - c, err := GitLabProvider{}.client(token) +func (p GitLabProvider) CurrentUser(ctx context.Context, token *oauth2.Token) (core.User, error) { + if token == nil || !token.Valid() { + return nil, nil + } + c, err := p.client(token) if err != nil { log.Printf("gitlab CurrentUser client: %v", err) return nil, err @@ -72,10 +77,10 @@ func (GitLabProvider) CurrentUser(ctx context.Context, token *oauth2.Token) (*Us log.Printf("gitlab CurrentUser lookup: %v", err) return nil, err } - return &User{Login: u.Username}, nil + return &core.BasicUser{Login: u.Username}, nil } -func (GitLabProvider) GetTags(ctx context.Context, user string, token *oauth2.Token) ([]*Tag, error) { +func (GitLabProvider) GetTags(ctx context.Context, user string, token *oauth2.Token) ([]*core.Tag, error) { c, err := GitLabProvider{}.client(token) if err != nil { log.Printf("gitlab GetTags client: %v", err) @@ -89,14 +94,14 @@ func (GitLabProvider) GetTags(ctx context.Context, user string, token *oauth2.To log.Printf("gitlab GetTags: %v", err) return nil, fmt.Errorf("ListTags: %w", err) } - res := make([]*Tag, 0, len(tags)) + res := make([]*core.Tag, 0, len(tags)) for _, t := range tags { - res = append(res, &Tag{Name: t.Name}) + res = append(res, &core.Tag{Name: t.Name}) } return res, nil } -func (GitLabProvider) GetBranches(ctx context.Context, user string, token *oauth2.Token) ([]*Branch, error) { +func (GitLabProvider) GetBranches(ctx context.Context, user string, token *oauth2.Token) ([]*core.Branch, error) { c, err := GitLabProvider{}.client(token) if err != nil { log.Printf("gitlab GetBranches client: %v", err) @@ -110,14 +115,14 @@ func (GitLabProvider) GetBranches(ctx context.Context, user string, token *oauth log.Printf("gitlab GetBranches: %v", err) return nil, fmt.Errorf("ListBranches: %w", err) } - res := make([]*Branch, 0, len(bs)) + res := make([]*core.Branch, 0, len(bs)) for _, b := range bs { - res = append(res, &Branch{Name: b.Name}) + res = append(res, &core.Branch{Name: b.Name}) } return res, nil } -func (GitLabProvider) GetCommits(ctx context.Context, user string, token *oauth2.Token, ref string, page, perPage int) ([]*Commit, error) { +func (GitLabProvider) GetCommits(ctx context.Context, user string, token *oauth2.Token, ref string, page, perPage int) ([]*core.Commit, error) { c, err := GitLabProvider{}.client(token) if err != nil { log.Printf("gitlab GetCommits client: %v", err) @@ -131,9 +136,9 @@ func (GitLabProvider) GetCommits(ctx context.Context, user string, token *oauth2 log.Printf("gitlab GetCommits: %v", err) return nil, fmt.Errorf("ListCommits: %w", err) } - res := make([]*Commit, 0, len(cs)) + res := make([]*core.Commit, 0, len(cs)) for _, commit := range cs { - res = append(res, &Commit{ + res = append(res, &core.Commit{ SHA: commit.ID, Message: commit.Message, CommitterName: commit.CommitterName, diff --git a/provider_sql.go b/provider_sql.go index c2bddd6..4a07757 100644 --- a/provider_sql.go +++ b/provider_sql.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "github.com/arran4/gobookmarks/core" _ "github.com/go-sql-driver/mysql" _ "github.com/mattn/go-sqlite3" "golang.org/x/crypto/bcrypt" @@ -32,6 +33,12 @@ func init() { RegisterProvider(&SQLProvider{}) } +// NewSQLProvider creates a new SQLProvider with the given database connection. +// This is used for integrating with existing database pools (e.g. goa4web). +func NewSQLProvider(db *sql.DB) *SQLProvider { + return &SQLProvider{db: db} +} + func (p *SQLProvider) getDB() (*sql.DB, error) { p.mu.Lock() defer p.mu.Unlock() @@ -51,11 +58,11 @@ func (p *SQLProvider) getDB() (*sql.DB, error) { func (p *SQLProvider) Name() string { return "sql" } func (p *SQLProvider) DefaultServer() string { return "" } func (p *SQLProvider) Config(clientID, clientSecret, redirectURL string) *oauth2.Config { return nil } -func (p *SQLProvider) CurrentUser(ctx context.Context, token *oauth2.Token) (*User, error) { - return nil, errors.New("not implemented") +func (p *SQLProvider) CurrentUser(ctx context.Context, token *oauth2.Token) (core.User, error) { + return nil, nil } -func (p *SQLProvider) GetTags(ctx context.Context, user string, token *oauth2.Token) ([]*Tag, error) { +func (p *SQLProvider) GetTags(ctx context.Context, user string, token *oauth2.Token) ([]*core.Tag, error) { db, err := p.getDB() if err != nil { return nil, err @@ -67,18 +74,18 @@ func (p *SQLProvider) GetTags(ctx context.Context, user string, token *oauth2.To } defer rows.Close() - var tags []*Tag + var tags []*core.Tag for rows.Next() { var n string if err := rows.Scan(&n); err != nil { return nil, err } - tags = append(tags, &Tag{Name: n}) + tags = append(tags, &core.Tag{Name: n}) } return tags, rows.Err() } -func (p *SQLProvider) GetBranches(ctx context.Context, user string, token *oauth2.Token) ([]*Branch, error) { +func (p *SQLProvider) GetBranches(ctx context.Context, user string, token *oauth2.Token) ([]*core.Branch, error) { db, err := p.getDB() if err != nil { return nil, err @@ -90,21 +97,21 @@ func (p *SQLProvider) GetBranches(ctx context.Context, user string, token *oauth } defer rows.Close() - var branches []*Branch + var branches []*core.Branch for rows.Next() { var n string if err := rows.Scan(&n); err != nil { return nil, err } - branches = append(branches, &Branch{Name: n}) + branches = append(branches, &core.Branch{Name: n}) } if len(branches) == 0 { - branches = append(branches, &Branch{Name: "main"}) + branches = append(branches, &core.Branch{Name: "main"}) } return branches, rows.Err() } -func (p *SQLProvider) GetCommits(ctx context.Context, user string, token *oauth2.Token, ref string, page, perPage int) ([]*Commit, error) { +func (p *SQLProvider) GetCommits(ctx context.Context, user string, token *oauth2.Token, ref string, page, perPage int) ([]*core.Commit, error) { db, err := p.getDB() if err != nil { return nil, err @@ -122,14 +129,14 @@ func (p *SQLProvider) GetCommits(ctx context.Context, user string, token *oauth2 } defer rows.Close() - var commits []*Commit + var commits []*core.Commit for rows.Next() { var sha, msg string var t time.Time if err := rows.Scan(&sha, &msg, &t); err != nil { return nil, err } - commits = append(commits, &Commit{ + commits = append(commits, &core.Commit{ SHA: sha, Message: msg, CommitterName: "gobookmarks", diff --git a/router.go b/router.go index 30e36bf..49c2b89 100644 --- a/router.go +++ b/router.go @@ -1,57 +1,46 @@ package gobookmarks import ( - "database/sql" + "context" "net/http" + "github.com/arran4/gobookmarks/app" + "github.com/arran4/gobookmarks/core" "github.com/arran4/gorillamuxlogic" "github.com/gorilla/mux" - "github.com/gorilla/sessions" ) -// RouterConfig holds the dependencies and configuration for the gobookmarks application -type RouterConfig struct { - // DB is the database connection - DB *sql.DB - - // UserProvider handles user authentication and lookup - UserProvider UserProvider - - // SessionStore handles session management - SessionStore sessions.Store - // SessionName is the name of the session cookie - SessionName string - - // BaseURL is the prefix for all routes (e.g. "/bookmarks") - BaseURL string - - // ExternalURL is the public facing URL - ExternalURL string - - // DevMode enables development features - DevMode bool -} - -// UserProvider defines the interface for external user management -type UserProvider interface { - // CurrentUser returns the current user from the request context - CurrentUser(r *http.Request) (*User, error) - // IsLoggedIn checks if a user is logged in - IsLoggedIn(r *http.Request) bool -} - -// NewRouter creates a new router with the given configuration -func NewRouter(cfg *RouterConfig) http.Handler { +// NewRouter creates a new router with the given application dependencies +func NewRouter(a *app.App) http.Handler { // Initialize globals temporarily until full refactor - if cfg.ExternalURL != "" { + // Initialize globals from config + if a.Config.ExternalURL != "" { // ExternalUrl is a global in gobookmarks package, need to verify } + GithubClientID = a.Config.GithubClientID + GithubClientSecret = a.Config.GithubSecret + GithubServer = a.Config.GithubServer + GitlabClientID = a.Config.GitlabClientID + GitlabClientSecret = a.Config.GitlabSecret + GitlabServer = a.Config.GitlabServer + DevMode = a.Config.DevMode + + UseCssColumns = a.Config.CssColumns + SiteTitle = a.Config.Title + NoFooter = a.Config.NoFooter + LocalGitPath = a.Config.LocalGitPath + CommitsPerPage = a.Config.CommitsPerPage + FaviconCacheDir = a.Config.FaviconCacheDir + FaviconCacheSize = a.Config.FaviconCacheSize + FaviconMaxCacheCount = a.Config.FaviconMaxCacheCount + if a.Config.ExternalURL != "" { + OauthRedirectURL = a.Config.ExternalURL + "/oauth2Callback" + } // Note: We are not handling all globals yet, focusing on router structure first r := mux.NewRouter() - r.Use(UserAdderMiddleware) // Middleware needs to be adapted to use UserProvider if present - r.Use(CoreAdderMiddleware) + r.Use(CoreDataMiddleware(a)) r.HandleFunc("/main.css", func(writer http.ResponseWriter, request *http.Request) { _, _ = writer.Write(GetMainCSSData()) @@ -61,7 +50,7 @@ func NewRouter(cfg *RouterConfig) http.Handler { }).Methods("GET") // Development helpers to toggle layout mode - if cfg.DevMode { + if a.Config.DevMode { r.HandleFunc("/_css", runHandlerChain(EnableCssColumnsAction, redirectToHandler("/"))).Methods("GET") r.HandleFunc("/_table", runHandlerChain(DisableCssColumnsAction, redirectToHandler("/"))).Methods("GET") } @@ -153,3 +142,46 @@ func NewRouter(cfg *RouterConfig) http.Handler { return r } + +func CoreDataMiddleware(a *app.App) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + session, _ := a.SessionStore.Get(r, a.Config.SessionName) // Handle error? + // ... (Existing logic adaptation) + + // For CoreData construction: + cd := &core.CoreData{ + Title: SiteTitle, // Use config or global + AutoRefresh: false, // Default + EditMode: r.URL.Query().Get("edit") == "1", + Tab: TabFromRequest(r), + Repo: a.Repo, + Session: session, + // User: ... populated below or separate middleware + } + + // User logic + if a.UserProvider != nil { + if u, err := a.UserProvider.CurrentUser(r); err == nil && u != nil { + cd.User = u + cd.UserRef = u.GetLogin() + } + } else { + // Fallback to legacy session lookup (for standalone if Provider not set) + if session != nil { + if u, ok := session.Values["GithubUser"].(*core.BasicUser); ok { + cd.User = u + cd.UserRef = u.GetLogin() + } + } + } + + // Inject CoreData + ctx := context.WithValue(r.Context(), core.ContextValues("coreData"), cd) + // Also "session" for legacy handlers if needed + ctx = context.WithValue(ctx, core.ContextValues("session"), session) + + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} diff --git a/router_helpers.go b/router_helpers.go index 713aec9..bc73c7d 100644 --- a/router_helpers.go +++ b/router_helpers.go @@ -12,7 +12,6 @@ import ( "strconv" "github.com/gorilla/mux" - "github.com/gorilla/sessions" "crypto/ecdsa" "crypto/elliptic" @@ -22,6 +21,8 @@ import ( "encoding/pem" "math/big" "time" + + "github.com/arran4/gobookmarks/core" ) func runHandlerChain(chain ...any) func(http.ResponseWriter, *http.Request) { @@ -43,8 +44,8 @@ func runHandlerChain(chain ...any) func(http.ResponseWriter, *http.Request) { if logoutErr := UserLogoutAction(w, r); logoutErr != nil { log.Printf("logout error: %v", logoutErr) } - type Data struct{ *CoreData } - if err := GetCompiledTemplates(NewFuncs(r)).ExecuteTemplate(w, "logoutPage.gohtml", Data{r.Context().Value(ContextValues("coreData")).(*CoreData)}); err != nil { + type Data struct{ *core.CoreData } + if err := GetCompiledTemplates(NewFuncs(r)).ExecuteTemplate(w, "logoutPage.gohtml", Data{r.Context().Value(core.ContextValues("coreData")).(*core.CoreData)}); err != nil { log.Printf("Logout Template Error: %s", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) } @@ -82,11 +83,11 @@ func runHandlerChain(chain ...any) func(http.ResponseWriter, *http.Request) { log.Printf("handler error: %v", err) type ErrorData struct { - *CoreData + *core.CoreData Error string } if err := GetCompiledTemplates(NewFuncs(r)).ExecuteTemplate(w, "error.gohtml", ErrorData{ - CoreData: r.Context().Value(ContextValues("coreData")).(*CoreData), + CoreData: r.Context().Value(core.ContextValues("coreData")).(*core.CoreData), Error: display, }); err != nil { log.Printf("Error Template Error: %s", err) @@ -104,12 +105,12 @@ func runHandlerChain(chain ...any) func(http.ResponseWriter, *http.Request) { func runTemplate(tmpl string) func(http.ResponseWriter, *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { type Data struct { - *CoreData + *core.CoreData Error string } data := Data{ - CoreData: r.Context().Value(ContextValues("coreData")).(*CoreData), + CoreData: r.Context().Value(core.ContextValues("coreData")).(*core.CoreData), Error: r.URL.Query().Get("error"), } @@ -124,7 +125,7 @@ func runTemplate(tmpl string) func(http.ResponseWriter, *http.Request) { if logoutErr := UserLogoutAction(w, r); logoutErr != nil { log.Printf("logout error: %v", logoutErr) } - type LogoutData struct{ *CoreData } + type LogoutData struct{ *core.CoreData } if tplErr := GetCompiledTemplates(NewFuncs(r)).ExecuteTemplate(w, "logoutPage.gohtml", LogoutData{data.CoreData}); tplErr != nil { log.Printf("Logout Template Error: %v", tplErr) http.Error(w, "Internal Server Error", http.StatusInternalServerError) @@ -142,7 +143,7 @@ func runTemplate(tmpl string) func(http.ResponseWriter, *http.Request) { log.Printf("Template %s error: %v", tmpl, err) type ErrorData struct { - *CoreData + *core.CoreData Error string } @@ -168,14 +169,14 @@ func redirectToHandlerBranchToRef(toUrl string) func(http.ResponseWriter, *http. qs := u.Query() qs.Set("ref", "refs/heads/"+r.PostFormValue("branch")) tab := TabFromRequest(r) - if v, ok := r.Context().Value(ContextValues("redirectTab")).(string); ok { + if v, ok := r.Context().Value(core.ContextValues("redirectTab")).(string); ok { if parsed, err := strconv.Atoi(v); err == nil { tab = parsed } } u.Path = TabPath(tab) page := r.PostFormValue("page") - if v, ok := r.Context().Value(ContextValues("redirectPage")).(string); ok { + if v, ok := r.Context().Value(core.ContextValues("redirectPage")).(string); ok { page = v } if fragment := PageFragmentFromIndex(page); fragment != "" { @@ -207,25 +208,21 @@ func redirectToHandlerTabPage(toUrl string) func(http.ResponseWriter, *http.Requ func RequiresAnAccount() mux.MatcherFunc { return func(request *http.Request, match *mux.RouteMatch) bool { - var session *sessions.Session - sessioni := request.Context().Value(ContextValues("session")) - if sessioni == nil { - var err error - session, err = SessionStore.Get(request, SessionName) - if err != nil { - return false - } - } else { - var ok bool - session, ok = sessioni.(*sessions.Session) - if !ok { - return false - } + if DevMode { + return true + } + cc := GetCore(request.Context()) + if cc == nil { + return false + } + session := cc.GetSession() + if session == nil { + return false } if v, ok := session.Values["version"].(string); !ok || v != version { return false } - githubUser, ok := session.Values["GithubUser"].(*User) + githubUser, ok := session.Values["GithubUser"].(*core.BasicUser) return ok && githubUser != nil } } diff --git a/router_helpers_test.go b/router_helpers_test.go index 8788301..161cc4a 100644 --- a/router_helpers_test.go +++ b/router_helpers_test.go @@ -8,6 +8,7 @@ import ( "strings" "testing" + "github.com/arran4/gobookmarks/core" "github.com/gorilla/sessions" "golang.org/x/oauth2" ) @@ -18,7 +19,7 @@ func TestRunHandlerChain_UserErrorRedirect(t *testing.T) { req := httptest.NewRequest("GET", "/submit", nil) req.Header.Set("Referer", "/form") - ctx := context.WithValue(req.Context(), ContextValues("coreData"), &CoreData{}) + ctx := context.WithValue(req.Context(), core.ContextValues("coreData"), &core.CoreData{}) req = req.WithContext(ctx) h := runHandlerChain(func(w http.ResponseWriter, r *http.Request) error { @@ -44,11 +45,11 @@ func TestRunTemplate_BufferedError(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) sess, _ := SessionStore.New(req, SessionName) - sess.Values["GithubUser"] = &User{Login: "user"} + sess.Values["GithubUser"] = &core.BasicUser{Login: "user"} sess.Values["Token"] = &oauth2.Token{} - ctx := context.WithValue(req.Context(), ContextValues("session"), sess) - ctx = context.WithValue(ctx, ContextValues("provider"), "sql") - ctx = context.WithValue(ctx, ContextValues("coreData"), &CoreData{UserRef: "user"}) + ctx := context.WithValue(req.Context(), core.ContextValues("session"), sess) + ctx = context.WithValue(ctx, core.ContextValues("provider"), "sql") + ctx = context.WithValue(ctx, core.ContextValues("coreData"), &core.CoreData{UserRef: "user"}) req = req.WithContext(ctx) w := httptest.NewRecorder() diff --git a/session_test.go b/session_test.go index f5bb26e..d87d739 100644 --- a/session_test.go +++ b/session_test.go @@ -1,9 +1,11 @@ package gobookmarks import ( - "github.com/gorilla/sessions" "net/http/httptest" "testing" + + "github.com/arran4/gobookmarks/core" + "github.com/gorilla/sessions" ) // Test that getSession clears outdated sessions and returns a fresh one. @@ -18,7 +20,7 @@ func Test_getSessionClearsOldVersion(t *testing.T) { // create a session with an old version s, _ := SessionStore.New(req, SessionName) s.Values["version"] = "old" - s.Values["GithubUser"] = &User{Login: "old"} + s.Values["GithubUser"] = &core.BasicUser{Login: "old"} if err := s.Save(req, w); err != nil { t.Fatalf("save old session: %v", err) } diff --git a/signup_git_http_test.go b/signup_git_http_test.go index 955ed2c..acdd0c7 100644 --- a/signup_git_http_test.go +++ b/signup_git_http_test.go @@ -9,6 +9,7 @@ import ( "strings" "testing" + "github.com/arran4/gobookmarks/core" "github.com/gorilla/sessions" ) @@ -62,9 +63,9 @@ func TestGitSignupScenario(t *testing.T) { if err != nil { t.Fatalf("getSession: %v", err) } - ctx := context.WithValue(sessReq.Context(), ContextValues("session"), session) - ctx = context.WithValue(ctx, ContextValues("provider"), "git") - ctx = context.WithValue(ctx, ContextValues("coreData"), &CoreData{}) + ctx := context.WithValue(sessReq.Context(), core.ContextValues("session"), session) + ctx = context.WithValue(ctx, core.ContextValues("provider"), "git") + ctx = context.WithValue(ctx, core.ContextValues("coreData"), &core.CoreData{Session: session}) // create bookmarks on new branch createText := "Category: New\nhttp://example.com new" diff --git a/tabEditHandlers.go b/tabEditHandlers.go index ba8f565..6f3c27f 100644 --- a/tabEditHandlers.go +++ b/tabEditHandlers.go @@ -3,18 +3,19 @@ package gobookmarks import ( "context" "fmt" - "github.com/gorilla/sessions" - "golang.org/x/oauth2" "net/http" "strconv" "strings" + + "github.com/arran4/gobookmarks/core" + "golang.org/x/oauth2" ) func EditTabPage(w http.ResponseWriter, r *http.Request) error { tabName := r.URL.Query().Get("name") tabIdx := TabFromRequest(r) - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := GetCore(r.Context()).GetSession() + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) ref := r.URL.Query().Get("ref") @@ -52,14 +53,14 @@ func EditTabPage(w http.ResponseWriter, r *http.Request) error { } data := struct { - *CoreData + *core.CoreData Error string Name string OldName string Text string Sha string }{ - CoreData: r.Context().Value(ContextValues("coreData")).(*CoreData), + CoreData: r.Context().Value(core.ContextValues("coreData")).(*core.CoreData), Error: r.URL.Query().Get("error"), Name: tabName, OldName: tabName, @@ -82,8 +83,8 @@ func TabEditSaveAction(w http.ResponseWriter, r *http.Request) error { tabIdx := TabFromRequest(r) sha := r.PostFormValue("sha") - session := r.Context().Value(ContextValues("session")).(*sessions.Session) - githubUser, _ := session.Values["GithubUser"].(*User) + session := GetCore(r.Context()).GetSession() + githubUser, _ := session.Values["GithubUser"].(*core.BasicUser) token, _ := session.Values["Token"].(*oauth2.Token) login := "" @@ -119,8 +120,8 @@ func TabEditSaveAction(w http.ResponseWriter, r *http.Request) error { return fmt.Errorf("updateBookmark error: %w", err) } if oldName == "" { - ctx := context.WithValue(r.Context(), ContextValues("redirectTab"), strconv.Itoa(newIndex)) - ctx = context.WithValue(ctx, ContextValues("redirectPage"), "0") + ctx := context.WithValue(r.Context(), core.ContextValues("redirectTab"), strconv.Itoa(newIndex)) + ctx = context.WithValue(ctx, core.ContextValues("redirectPage"), "0") *r = *r.WithContext(ctx) } return nil