Skip to content
This repository has been archived by the owner on Mar 29, 2024. It is now read-only.

Commit

Permalink
feat(instill): use grpc client for all request (#108)
Browse files Browse the repository at this point in the history
Because

- we want to use the same connection approach for getting models,
triggering and testing connection.

This commit

- use grpc client for getting models and testing connection
- remove the unit test for `Test()`
  • Loading branch information
donch1989 committed Jan 14, 2024
1 parent 96d2107 commit 548a78d
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 140 deletions.
83 changes: 0 additions & 83 deletions pkg/instill/connector_test.go

This file was deleted.

109 changes: 52 additions & 57 deletions pkg/instill/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,22 @@ package instill

import (
"context"
"crypto/tls"
_ "embed"
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"

"github.com/gofrs/uuid"
"go.uber.org/zap"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"

"github.com/instill-ai/component/pkg/base"
"github.com/instill-ai/connector/pkg/util/httpclient"

commonPB "github.com/instill-ai/protogen-go/common/task/v1alpha"
mgmtv1beta "github.com/instill-ai/protogen-go/core/mgmt/v1beta"
mgmtPB "github.com/instill-ai/protogen-go/core/mgmt/v1beta"
modelPB "github.com/instill-ai/protogen-go/model/model/v1alpha"
pipelinePB "github.com/instill-ai/protogen-go/vdp/pipeline/v1beta"
)

Expand Down Expand Up @@ -77,7 +75,7 @@ func getInstillUserUID(config *structpb.Struct) string {
return config.GetFields()["instill_user_uid"].GetStringValue()
}

func getServerURL(config *structpb.Struct) string {
func getModelServerURL(config *structpb.Struct) string {
if getMode(config) == internalMode {
return config.GetFields()["instill_model_backend"].GetStringValue()
}
Expand All @@ -94,32 +92,51 @@ func getServerURL(config *structpb.Struct) string {
return serverURL
}

func getMgmtServerURL(config *structpb.Struct) string {
if getMode(config) == internalMode {
return config.GetFields()["instill_mgmt_backend"].GetStringValue()
}
serverURL := config.GetFields()["server_url"].GetStringValue()
if strings.HasPrefix(serverURL, "https://") {
if len(strings.Split(serverURL, ":")) == 2 {
serverURL = serverURL + ":443"
}
} else if strings.HasPrefix(serverURL, "http://") {
if len(strings.Split(serverURL, ":")) == 2 {
serverURL = serverURL + ":80"
}
}
return serverURL
}

func (e *Execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) {
var err error

if len(inputs) <= 0 || inputs[0] == nil {
return inputs, fmt.Errorf("invalid input")
}

gRPCCLient, gRPCCLientConn := initModelPublicServiceClient(getServerURL(e.Config))
gRPCCLient, gRPCCLientConn := initModelPublicServiceClient(getModelServerURL(e.Config))
if gRPCCLientConn != nil {
defer gRPCCLientConn.Close()
}

mgmtGRPCCLient, mgmtGRPCCLientConn := initMgmtPublicServiceClient(getServerURL(e.Config))
mgmtGRPCCLient, mgmtGRPCCLientConn := initMgmtPublicServiceClient(getMgmtServerURL(e.Config))
if mgmtGRPCCLientConn != nil {
defer mgmtGRPCCLientConn.Close()
}

modelNameSplits := strings.Split(inputs[0].GetFields()["model_name"].GetStringValue(), "/")
nsResp, err := mgmtGRPCCLient.CheckNamespace(context.Background(), &mgmtv1beta.CheckNamespaceRequest{
md := metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", getAPIKey(e.Config)), "Instill-User-Uid", getInstillUserUID(e.Config))
ctx := metadata.NewOutgoingContext(context.Background(), md)
nsResp, err := mgmtGRPCCLient.CheckNamespace(ctx, &mgmtPB.CheckNamespaceRequest{
Id: modelNameSplits[0],
})
if err != nil {
return nil, err
}
nsType := ""
if nsResp.Type == mgmtv1beta.CheckNamespaceResponse_NAMESPACE_ORGANIZATION {
if nsResp.Type == mgmtPB.CheckNamespaceResponse_NAMESPACE_ORGANIZATION {
nsType = "organizations"
} else {
nsType = "users"
Expand Down Expand Up @@ -161,45 +178,20 @@ func (e *Execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, erro
}

func (c *Connector) Test(_ uuid.UUID, config *structpb.Struct, logger *zap.Logger) (pipelinePB.Connector_State, error) {
req := newHTTPClient(config, logger).R()

path := "/model" + getModelPath
if resp, err := req.Get(path); err != nil || resp.IsError() {
gRPCCLient, gRPCCLientConn := initModelPublicServiceClient(getModelServerURL(config))
if gRPCCLientConn != nil {
defer gRPCCLientConn.Close()
}
md := metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", getAPIKey(config)), "Instill-User-Uid", getInstillUserUID(config))
ctx := metadata.NewOutgoingContext(context.Background(), md)
_, err := gRPCCLient.ListModels(ctx, &modelPB.ListModelsRequest{})
if err != nil {
return pipelinePB.Connector_STATE_ERROR, err
}

return pipelinePB.Connector_STATE_CONNECTED, nil
}

type errBody struct {
Msg string `json:"message"`
}

func (e errBody) Message() string {
return e.Msg
}

func newHTTPClient(config *structpb.Struct, logger *zap.Logger) *httpclient.Client {
c := httpclient.New("Instill AI", getServerURL(config),
httpclient.WithLogger(logger),
httpclient.WithEndUserError(new(errBody)),
)

c.SetTransport(&http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
DisableKeepAlives: true,
})

if token := getAPIKey(config); token != "" {
c.SetAuthToken(token)
}

if userID := getInstillUserUID(config); userID != "" {
c.SetHeader("Instill-User-Uid", userID)
}

return c
}
func (c *Connector) GetConnectorDefinitionByID(defID string, resourceConfig *structpb.Struct, componentConfig *structpb.Struct) (*pipelinePB.ConnectorDefinition, error) {
def, err := c.Connector.GetConnectorDefinitionByID(defID, resourceConfig, componentConfig)
if err != nil {
Expand All @@ -225,24 +217,27 @@ func (c *Connector) GetConnectorDefinitionByUID(defUID uuid.UUID, resourceConfig
def := proto.Clone(oriDef).(*pipelinePB.ConnectorDefinition)

if resourceConfig != nil {
req := newHTTPClient(resourceConfig, c.Logger).R()
var modelsResp ModelsResp
path := "/model" + getModelPath
gRPCCLient, gRPCCLientConn := initModelPublicServiceClient(getModelServerURL(resourceConfig))
if gRPCCLientConn != nil {
defer gRPCCLientConn.Close()
}
md := metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", getAPIKey(resourceConfig)), "Instill-User-Uid", getInstillUserUID(resourceConfig))
ctx := metadata.NewOutgoingContext(context.Background(), md)
resp, err := gRPCCLient.ListModels(ctx, &modelPB.ListModelsRequest{})
if err != nil {
return def, nil
}

modelNameMap := map[string]*structpb.ListValue{}
if resp, err := req.Get(path); err == nil && !resp.IsError() {
_ = json.Unmarshal(resp.Body(), &modelsResp)

modelName := &structpb.ListValue{}
for _, model := range modelsResp.Models {
if _, ok := modelNameMap[model.Task]; !ok {
modelNameMap[model.Task] = &structpb.ListValue{}
}
namePaths := strings.Split(model.Name, "/")
modelName.Values = append(modelName.Values, structpb.NewStringValue(fmt.Sprintf("%s/%s", namePaths[1], namePaths[3])))
modelNameMap[model.Task].Values = append(modelNameMap[model.Task].Values, structpb.NewStringValue(fmt.Sprintf("%s/%s", namePaths[1], namePaths[3])))
modelName := &structpb.ListValue{}
for _, model := range resp.Models {
if _, ok := modelNameMap[model.Task.String()]; !ok {
modelNameMap[model.Task.String()] = &structpb.ListValue{}
}

namePaths := strings.Split(model.Name, "/")
modelName.Values = append(modelName.Values, structpb.NewStringValue(fmt.Sprintf("%s/%s", namePaths[1], namePaths[3])))
modelNameMap[model.Task.String()].Values = append(modelNameMap[model.Task.String()].Values, structpb.NewStringValue(fmt.Sprintf("%s/%s", namePaths[1], namePaths[3])))
}
for _, sch := range def.Spec.ComponentSpecification.Fields["oneOf"].GetListValue().Values {
task := sch.GetStructValue().Fields["properties"].GetStructValue().Fields["task"].GetStructValue().Fields["const"].GetStringValue()
Expand Down

0 comments on commit 548a78d

Please sign in to comment.