diff --git a/go.mod b/go.mod index 05720f0..623fdf2 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/graphql-go/handler go 1.14 + +require github.com/graphql-go/graphql v0.8.1 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..5f20647 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/graphql-go/graphql v0.8.1 h1:p7/Ou/WpmulocJeEx7wjQy611rtXGQaAcXGqanuMMgc= +github.com/graphql-go/graphql v0.8.1/go.mod h1:nKiHzRM0qopJEwCITUuIsxk9PlVlwIiiI8pnJEhordQ= diff --git a/graphcoolPlayground.go b/graphcoolPlayground.go index 43be574..02abaeb 100644 --- a/graphcoolPlayground.go +++ b/graphcoolPlayground.go @@ -1,7 +1,6 @@ package handler import ( - "fmt" "html/template" "net/http" ) @@ -14,7 +13,7 @@ type playgroundData struct { } // renderPlayground renders the Playground GUI -func renderPlayground(w http.ResponseWriter, r *http.Request) { +func renderPlayground(w http.ResponseWriter, r *http.Request, endpoint string, subscriptionEndpoint string) { t := template.New("Playground") t, err := t.Parse(graphcoolPlaygroundTemplate) if err != nil { @@ -24,16 +23,14 @@ func renderPlayground(w http.ResponseWriter, r *http.Request) { d := playgroundData{ PlaygroundVersion: graphcoolPlaygroundVersion, - Endpoint: r.URL.Path, - SubscriptionEndpoint: fmt.Sprintf("ws://%v/subscriptions", r.Host), + Endpoint: endpoint, + SubscriptionEndpoint: subscriptionEndpoint, SetTitle: true, } err = t.ExecuteTemplate(w, "index", d) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } - - return } const graphcoolPlaygroundVersion = "1.5.2" diff --git a/handler.go b/handler.go index b9a647c..69f150b 100644 --- a/handler.go +++ b/handler.go @@ -2,6 +2,7 @@ package handler import ( "encoding/json" + "fmt" "io/ioutil" "net/http" "net/url" @@ -27,6 +28,7 @@ type Handler struct { pretty bool graphiql bool playground bool + playgroundConfig *PlaygroundConfig rootObjectFn RootObjectFn resultCallbackFn ResultCallbackFn formatErrorFn func(err error) gqlerrors.FormattedError @@ -162,7 +164,15 @@ func (h *Handler) ContextHandler(ctx context.Context, w http.ResponseWriter, r * acceptHeader := r.Header.Get("Accept") _, raw := r.URL.Query()["raw"] if !raw && !strings.Contains(acceptHeader, "application/json") && strings.Contains(acceptHeader, "text/html") { - renderPlayground(w, r) + + endpoint := r.URL.Path + subscriptionEndpoint := fmt.Sprintf("ws://%v/subscriptions", r.Host) + if h.playgroundConfig != nil { + endpoint = h.playgroundConfig.Endpoint + subscriptionEndpoint = h.playgroundConfig.SubscriptionEndpoint + } + + renderPlayground(w, r, endpoint, subscriptionEndpoint) return } } @@ -196,11 +206,17 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // RootObjectFn allows a user to generate a RootObject per request type RootObjectFn func(ctx context.Context, r *http.Request) map[string]interface{} +type PlaygroundConfig struct { + Endpoint string + SubscriptionEndpoint string +} + type Config struct { Schema *graphql.Schema Pretty bool GraphiQL bool Playground bool + PlaygroundConfig *PlaygroundConfig RootObjectFn RootObjectFn ResultCallbackFn ResultCallbackFn FormatErrorFn func(err error) gqlerrors.FormattedError @@ -208,10 +224,11 @@ type Config struct { func NewConfig() *Config { return &Config{ - Schema: nil, - Pretty: true, - GraphiQL: true, - Playground: false, + Schema: nil, + Pretty: true, + GraphiQL: true, + Playground: false, + PlaygroundConfig: nil, } } @@ -229,6 +246,7 @@ func New(p *Config) *Handler { pretty: p.Pretty, graphiql: p.GraphiQL, playground: p.Playground, + playgroundConfig: p.PlaygroundConfig, rootObjectFn: p.RootObjectFn, resultCallbackFn: p.ResultCallbackFn, formatErrorFn: p.FormatErrorFn, diff --git a/handler_test.go b/handler_test.go index 4154b1a..39d8e20 100644 --- a/handler_test.go +++ b/handler_test.go @@ -290,3 +290,113 @@ func TestHandler_BasicQuery_WithFormatErrorFn(t *testing.T) { t.Fatalf("wrong result, graphql result diff: %v", testutil.Diff(expected, result)) } } + +func TestPlaygroundWithDefaultConfig(t *testing.T) { + query := graphql.NewObject(graphql.ObjectConfig{ + Name: "Query", + Fields: graphql.Fields{ + "ping": &graphql.Field{ + Name: "ping", + Type: graphql.String, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return "OK", nil + }, + }, + }, + }) + + schema, err := graphql.NewSchema(graphql.SchemaConfig{ + Query: query, + }) + if err != nil { + t.Fatal(err) + } + + req, err := http.NewRequest("GET", "/graphql", nil) + req.Header.Set("Accept", "text/html") + if err != nil { + t.Fatal(err) + } + + h := handler.New(&handler.Config{ + Schema: &schema, + Playground: true, + }) + + resp := httptest.NewRecorder() + h.ContextHandler(context.Background(), resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("unexpected server response %v", resp.Code) + } + + expectedBodyContains := []string{ + "GraphQL Playground", + `endpoint: "/graphql"`, + `subscriptionEndpoint: "ws:///subscriptions"`, + } + respBody := resp.Body.String() + + for _, e := range expectedBodyContains { + if !strings.Contains(respBody, e) { + t.Fatalf("wrong body, expected %s to contain %s", respBody, e) + } + } +} + +func TestPlaygroundWithCustomConfig(t *testing.T) { + query := graphql.NewObject(graphql.ObjectConfig{ + Name: "Query", + Fields: graphql.Fields{ + "ping": &graphql.Field{ + Name: "ping", + Type: graphql.String, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return "OK", nil + }, + }, + }, + }) + + schema, err := graphql.NewSchema(graphql.SchemaConfig{ + Query: query, + }) + if err != nil { + t.Fatal(err) + } + + req, err := http.NewRequest("GET", "/custom-path/graphql", nil) + req.Header.Set("Accept", "text/html") + if err != nil { + t.Fatal(err) + } + + h := handler.New(&handler.Config{ + Schema: &schema, + Playground: true, + PlaygroundConfig: &handler.PlaygroundConfig{ + Endpoint: "/custom-path/graphql", + SubscriptionEndpoint: "/custom-path/ws", + }, + }) + + resp := httptest.NewRecorder() + h.ContextHandler(context.Background(), resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("unexpected server response %v", resp.Code) + } + + expectedBodyContains := []string{ + "GraphQL Playground", + `endpoint: "/custom-path/graphql"`, + `subscriptionEndpoint: "/custom-path/ws"`, + } + respBody := resp.Body.String() + + for _, e := range expectedBodyContains { + if !strings.Contains(respBody, e) { + t.Fatalf("wrong body, expected %s to contain %s", respBody, e) + } + } +}