From 548a78d9ac89312c2060f3a8fc285a991cd9dd5d Mon Sep 17 00:00:00 2001 From: "Chang, Hui-Tang" Date: Sat, 13 Jan 2024 07:07:36 -0600 Subject: [PATCH] feat(instill): use grpc client for all request (#108) Because - we want to use the same connection approach for getting models, triggering and testing connection. This commit - use grpc client for getting models and testing connection - remove the unit test for `Test()` --- pkg/instill/connector_test.go | 83 -------------------------- pkg/instill/main.go | 109 ++++++++++++++++------------------ 2 files changed, 52 insertions(+), 140 deletions(-) delete mode 100644 pkg/instill/connector_test.go diff --git a/pkg/instill/connector_test.go b/pkg/instill/connector_test.go deleted file mode 100644 index ab3fc30..0000000 --- a/pkg/instill/connector_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package instill - -import ( - "net/http" - "net/http/httptest" - "testing" - - qt "github.com/frankban/quicktest" - "github.com/gofrs/uuid" - "go.uber.org/zap" - "google.golang.org/protobuf/types/known/structpb" - - pipelinePB "github.com/instill-ai/protogen-go/vdp/pipeline/v1beta" - "github.com/instill-ai/x/errmsg" -) - -const ( - apiKey = "123" -) - -func TestConnector_Test(t *testing.T) { - c := qt.New(t) - - logger := zap.NewNop() - connector := Init(logger) - userID, defID := uuid.Must(uuid.NewV4()), uuid.Must(uuid.NewV4()) - - wantPath := "/model/v1alpha/models" - c.Run("nok - error", func(c *qt.C) { - h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c.Check(r.Method, qt.Equals, http.MethodGet) - c.Check(r.URL.Path, qt.Equals, wantPath) - - c.Check(r.Header.Get("Authorization"), qt.Equals, "Bearer "+apiKey) - c.Check(r.Header.Get("Instill-User-Uid"), qt.Equals, userID.String()) - - w.WriteHeader(http.StatusBadRequest) - }) - - srv := httptest.NewServer(h) - c.Cleanup(srv.Close) - - config, err := structpb.NewStruct(map[string]any{ - "mode": "external", - "server_url": srv.URL, - "api_token": apiKey, - "instill_user_uid": userID.String(), - }) - c.Assert(err, qt.IsNil) - - got, err := connector.Test(defID, config, logger) - c.Check(err, qt.IsNotNil) - c.Check(got, qt.Equals, pipelinePB.Connector_STATE_ERROR) - - wantMsg := "Instill AI responded with a 400 status code. Please refer to Instill AI's API reference for more information." - c.Check(errmsg.Message(err), qt.Equals, wantMsg) - }) - - c.Run("ok - connected", func(c *qt.C) { - h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c.Check(r.Method, qt.Equals, http.MethodGet) - c.Check(r.URL.Path, qt.Equals, wantPath) - - c.Check(r.Header.Get("Authorization"), qt.Equals, "Bearer "+apiKey) - c.Check(r.Header.Get("Instill-User-Uid"), qt.Equals, userID.String()) - }) - - srv := httptest.NewServer(h) - c.Cleanup(srv.Close) - - config, err := structpb.NewStruct(map[string]any{ - "mode": "external", - "server_url": srv.URL, - "api_token": apiKey, - "instill_user_uid": userID.String(), - }) - c.Assert(err, qt.IsNil) - - got, err := connector.Test(defID, config, logger) - c.Check(err, qt.IsNil) - c.Check(got, qt.Equals, pipelinePB.Connector_STATE_CONNECTED) - }) -} diff --git a/pkg/instill/main.go b/pkg/instill/main.go index 8fbd3ee..ccb7085 100644 --- a/pkg/instill/main.go +++ b/pkg/instill/main.go @@ -2,24 +2,22 @@ package instill import ( "context" - "crypto/tls" _ "embed" - "encoding/json" "fmt" - "net/http" "strings" "sync" "github.com/gofrs/uuid" "go.uber.org/zap" + "google.golang.org/grpc/metadata" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/structpb" "github.com/instill-ai/component/pkg/base" - "github.com/instill-ai/connector/pkg/util/httpclient" commonPB "github.com/instill-ai/protogen-go/common/task/v1alpha" - mgmtv1beta "github.com/instill-ai/protogen-go/core/mgmt/v1beta" + mgmtPB "github.com/instill-ai/protogen-go/core/mgmt/v1beta" + modelPB "github.com/instill-ai/protogen-go/model/model/v1alpha" pipelinePB "github.com/instill-ai/protogen-go/vdp/pipeline/v1beta" ) @@ -77,7 +75,7 @@ func getInstillUserUID(config *structpb.Struct) string { return config.GetFields()["instill_user_uid"].GetStringValue() } -func getServerURL(config *structpb.Struct) string { +func getModelServerURL(config *structpb.Struct) string { if getMode(config) == internalMode { return config.GetFields()["instill_model_backend"].GetStringValue() } @@ -94,6 +92,23 @@ func getServerURL(config *structpb.Struct) string { return serverURL } +func getMgmtServerURL(config *structpb.Struct) string { + if getMode(config) == internalMode { + return config.GetFields()["instill_mgmt_backend"].GetStringValue() + } + serverURL := config.GetFields()["server_url"].GetStringValue() + if strings.HasPrefix(serverURL, "https://") { + if len(strings.Split(serverURL, ":")) == 2 { + serverURL = serverURL + ":443" + } + } else if strings.HasPrefix(serverURL, "http://") { + if len(strings.Split(serverURL, ":")) == 2 { + serverURL = serverURL + ":80" + } + } + return serverURL +} + func (e *Execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) { var err error @@ -101,25 +116,27 @@ func (e *Execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, erro return inputs, fmt.Errorf("invalid input") } - gRPCCLient, gRPCCLientConn := initModelPublicServiceClient(getServerURL(e.Config)) + gRPCCLient, gRPCCLientConn := initModelPublicServiceClient(getModelServerURL(e.Config)) if gRPCCLientConn != nil { defer gRPCCLientConn.Close() } - mgmtGRPCCLient, mgmtGRPCCLientConn := initMgmtPublicServiceClient(getServerURL(e.Config)) + mgmtGRPCCLient, mgmtGRPCCLientConn := initMgmtPublicServiceClient(getMgmtServerURL(e.Config)) if mgmtGRPCCLientConn != nil { defer mgmtGRPCCLientConn.Close() } modelNameSplits := strings.Split(inputs[0].GetFields()["model_name"].GetStringValue(), "/") - nsResp, err := mgmtGRPCCLient.CheckNamespace(context.Background(), &mgmtv1beta.CheckNamespaceRequest{ + md := metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", getAPIKey(e.Config)), "Instill-User-Uid", getInstillUserUID(e.Config)) + ctx := metadata.NewOutgoingContext(context.Background(), md) + nsResp, err := mgmtGRPCCLient.CheckNamespace(ctx, &mgmtPB.CheckNamespaceRequest{ Id: modelNameSplits[0], }) if err != nil { return nil, err } nsType := "" - if nsResp.Type == mgmtv1beta.CheckNamespaceResponse_NAMESPACE_ORGANIZATION { + if nsResp.Type == mgmtPB.CheckNamespaceResponse_NAMESPACE_ORGANIZATION { nsType = "organizations" } else { nsType = "users" @@ -161,45 +178,20 @@ func (e *Execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, erro } func (c *Connector) Test(_ uuid.UUID, config *structpb.Struct, logger *zap.Logger) (pipelinePB.Connector_State, error) { - req := newHTTPClient(config, logger).R() - - path := "/model" + getModelPath - if resp, err := req.Get(path); err != nil || resp.IsError() { + gRPCCLient, gRPCCLientConn := initModelPublicServiceClient(getModelServerURL(config)) + if gRPCCLientConn != nil { + defer gRPCCLientConn.Close() + } + md := metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", getAPIKey(config)), "Instill-User-Uid", getInstillUserUID(config)) + ctx := metadata.NewOutgoingContext(context.Background(), md) + _, err := gRPCCLient.ListModels(ctx, &modelPB.ListModelsRequest{}) + if err != nil { return pipelinePB.Connector_STATE_ERROR, err } return pipelinePB.Connector_STATE_CONNECTED, nil } -type errBody struct { - Msg string `json:"message"` -} - -func (e errBody) Message() string { - return e.Msg -} - -func newHTTPClient(config *structpb.Struct, logger *zap.Logger) *httpclient.Client { - c := httpclient.New("Instill AI", getServerURL(config), - httpclient.WithLogger(logger), - httpclient.WithEndUserError(new(errBody)), - ) - - c.SetTransport(&http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - DisableKeepAlives: true, - }) - - if token := getAPIKey(config); token != "" { - c.SetAuthToken(token) - } - - if userID := getInstillUserUID(config); userID != "" { - c.SetHeader("Instill-User-Uid", userID) - } - - return c -} func (c *Connector) GetConnectorDefinitionByID(defID string, resourceConfig *structpb.Struct, componentConfig *structpb.Struct) (*pipelinePB.ConnectorDefinition, error) { def, err := c.Connector.GetConnectorDefinitionByID(defID, resourceConfig, componentConfig) if err != nil { @@ -225,24 +217,27 @@ func (c *Connector) GetConnectorDefinitionByUID(defUID uuid.UUID, resourceConfig def := proto.Clone(oriDef).(*pipelinePB.ConnectorDefinition) if resourceConfig != nil { - req := newHTTPClient(resourceConfig, c.Logger).R() - var modelsResp ModelsResp - path := "/model" + getModelPath + gRPCCLient, gRPCCLientConn := initModelPublicServiceClient(getModelServerURL(resourceConfig)) + if gRPCCLientConn != nil { + defer gRPCCLientConn.Close() + } + md := metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", getAPIKey(resourceConfig)), "Instill-User-Uid", getInstillUserUID(resourceConfig)) + ctx := metadata.NewOutgoingContext(context.Background(), md) + resp, err := gRPCCLient.ListModels(ctx, &modelPB.ListModelsRequest{}) + if err != nil { + return def, nil + } modelNameMap := map[string]*structpb.ListValue{} - if resp, err := req.Get(path); err == nil && !resp.IsError() { - _ = json.Unmarshal(resp.Body(), &modelsResp) - modelName := &structpb.ListValue{} - for _, model := range modelsResp.Models { - if _, ok := modelNameMap[model.Task]; !ok { - modelNameMap[model.Task] = &structpb.ListValue{} - } - namePaths := strings.Split(model.Name, "/") - modelName.Values = append(modelName.Values, structpb.NewStringValue(fmt.Sprintf("%s/%s", namePaths[1], namePaths[3]))) - modelNameMap[model.Task].Values = append(modelNameMap[model.Task].Values, structpb.NewStringValue(fmt.Sprintf("%s/%s", namePaths[1], namePaths[3]))) + modelName := &structpb.ListValue{} + for _, model := range resp.Models { + if _, ok := modelNameMap[model.Task.String()]; !ok { + modelNameMap[model.Task.String()] = &structpb.ListValue{} } - + namePaths := strings.Split(model.Name, "/") + modelName.Values = append(modelName.Values, structpb.NewStringValue(fmt.Sprintf("%s/%s", namePaths[1], namePaths[3]))) + modelNameMap[model.Task.String()].Values = append(modelNameMap[model.Task.String()].Values, structpb.NewStringValue(fmt.Sprintf("%s/%s", namePaths[1], namePaths[3]))) } for _, sch := range def.Spec.ComponentSpecification.Fields["oneOf"].GetListValue().Values { task := sch.GetStructValue().Fields["properties"].GetStructValue().Fields["task"].GetStructValue().Fields["const"].GetStringValue()