diff --git a/README.md b/README.md index c746eaa..5f35752 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,9 @@ Currently, we support these AI vendors: | Name | Vendor | Supported Tasks | |--------------------|----------------------------------------------------------|-----------------| +| ai-archetype-ai | [Archetype AI](https://www.archetypeai.io/) | Image summarization, Video description | | ai-instill-model | [Instill Model](https://instill.tech/) | Classification, Instance-Segmentation, Keypoint, Detection, OCR, Semantic-Segmentation, Text-Generation, Text-to-Image| -| ai-openai | [OpenAI](https://openai.com) | Text-Generation, Speech-Recognition, Text-Embedding +| ai-openai | [OpenAI](https://openai.com) | Text-Generation, Speech-Recognition, Text-Embedding | | ai-stability-ai | [Stability AI](https://stability.ai/) | Text-to-Image, Image-to-Image | | airbyte-* | [Aiybyte](https://airbyte.com/) | WriteDestination | | data-pinecone | [Pinecone](https://www.pinecone.io/) | Upsert, Query | diff --git a/pkg/archetypeai/client.go b/pkg/archetypeai/client.go new file mode 100644 index 0000000..5844a9a --- /dev/null +++ b/pkg/archetypeai/client.go @@ -0,0 +1,33 @@ +package archetypeai + +import ( + "github.com/instill-ai/connector/pkg/util/httpclient" + "go.uber.org/zap" + "google.golang.org/protobuf/types/known/structpb" +) + +const ( + host = "https://api.archetypeai.dev" + describePath = "/v0.3/describe" + summarizePath = "/v0.3/summarize" + uploadFilePath = "/v0.3/files" +) + +func newClient(config *structpb.Struct, logger *zap.Logger) *httpclient.Client { + c := httpclient.New("Archetype AI", getBasePath(config), + httpclient.WithLogger(logger), + httpclient.WithEndUserError(new(errBody)), + ) + + c.SetAuthToken(getAPIKey(config)) + + return c +} + +type errBody struct { + Error string `json:"error"` +} + +func (e errBody) Message() string { + return e.Error +} diff --git a/pkg/archetypeai/config/definitions.json b/pkg/archetypeai/config/definitions.json new 
file mode 100644 index 0000000..56274fa --- /dev/null +++ b/pkg/archetypeai/config/definitions.json @@ -0,0 +1,41 @@ +[ + { + "available_tasks": [ + "TASK_DESCRIBE", + "TASK_SUMMARIZE", + "TASK_UPLOAD_FILE" + ], + "custom": false, + "documentation_url": "", + "icon": "Archetype AI/archetypeai.svg", + "icon_url": "", + "id": "archetype-ai", + "public": true, + "spec": { + "resource_specification": { + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": false, + "properties": { + "api_key": { + "description": "Fill your Archetype AI API key", + "instillCredentialField": true, + "instillUIOrder": 0, + "title": "API Key", + "type": "string" + } + }, + "required": [ + "api_key" + ], + "title": "Archetype AI Connector Specification", + "type": "object" + } + }, + "title": "Archetype AI", + "tombstone": false, + "type": "CONNECTOR_TYPE_AI", + "uid": "e414a1f8-5fdf-4292-b050-9f9176254a4b", + "vendor": "Archetype AI", + "vendor_attributes": {} + } +] diff --git a/pkg/archetypeai/config/tasks.json b/pkg/archetypeai/config/tasks.json new file mode 100644 index 0000000..02699c4 --- /dev/null +++ b/pkg/archetypeai/config/tasks.json @@ -0,0 +1,201 @@ +{ + "TASK_DESCRIBE": { + "input": { + "instillUIOrder": 0, + "properties": { + "query": { + "description": "A guide to describe the video", + "instillAcceptFormats": [ + "string" + ], + "instillUIMultiline": true, + "instillUIOrder": 0, + "instillUpstreamTypes": [ + "value", + "reference", + "template" + ], + "title": "Query", + "type": "string" + }, + "file_ids": { + "description": "The IDs of the videos to describe. 
These must have been previously uploaded via TASK_UPLOAD_FILE.", + "instillAcceptFormats": [ + "array:string" + ], + "instillUIOrder": 1, + "instillUpstreamTypes": [ + "value", + "reference" + ], + "items": { + "instillUIMultiline": false, + "type": "string" + }, + "minItems": 1, + "title": "File IDs", + "type": "array" + } + }, + "required": [ + "query", + "file_ids" + ], + "title": "Input", + "type": "object" + }, + "output": { + "instillUIOrder": 0, + "properties": { + "descriptions": { + "description": "A set of descriptions corresponding to different moments in the video", + "instillUIOrder": 0, + "title": "Descriptions", + "type": "array", + "items": { + "title": "Frame description", + "type": "object", + "properties": { + "frame_id": { + "description": "The frame number in the video that is being described", + "instillFormat": "integer", + "instillUIOrder": 3, + "required": [], + "title": "Frame ID", + "type": "integer" + }, + "timestamp": { + "description": "The moment of the video (in seconds since the start) that is being described", + "instillFormat": "number", + "instillUIOrder": 1, + "title": "Timestamp", + "type": "number" + }, + "description": { + "description": "The description of the frame", + "instillFormat": "string", + "instillUIOrder": 2, + "title": "Description", + "type": "string" + } + }, + "required": [ + "description", + "timestamp", + "frame_id" + ] + } + } + }, + "required": [ + "descriptions" + ], + "title": "Output", + "type": "object" + } + }, + "TASK_SUMMARIZE": { + "input": { + "instillUIOrder": 0, + "properties": { + "query": { + "description": "A guide to summarize the image", + "instillAcceptFormats": [ + "string" + ], + "instillUIMultiline": true, + "instillUIOrder": 0, + "instillUpstreamTypes": [ + "value", + "reference", + "template" + ], + "title": "Query", + "type": "string" + }, + "file_ids": { + "description": "The IDs of the images to summarize. 
These must have been previously uploaded via TASK_UPLOAD_FILE.", + "instillAcceptFormats": [ + "array:string" + ], + "instillUIOrder": 1, + "instillUpstreamTypes": [ + "value", + "reference" + ], + "items": { + "instillUIMultiline": false, + "type": "string" + }, + "minItems": 1, + "title": "File IDs", + "type": "array" + } + }, + "required": [ + "query", + "file_ids" + ], + "title": "Input", + "type": "object" + }, + "output": { + "instillUIOrder": 0, + "properties": { + "response": { + "description": "A text responding to the query", + "instillFormat": "string", + "instillUIOrder": 0, + "title": "Response", + "type": "string" + } + }, + "required": [ + "response" + ], + "title": "Output", + "type": "object" + } + }, + "TASK_UPLOAD_FILE": { + "input": { + "instillUIOrder": 0, + "properties": { + "file": { + "title": "File", + "description": "The file to upload. Accepted formats are JPEG and PNG for images or MP4 for videos", + "type": "string", + "instillAcceptFormats": [ + "image/*" + ], + "instillUIOrder": 0, + "instillUpstreamTypes": [ + "reference" + ] + } + }, + "required": [ + "file" + ], + "title": "Input", + "type": "object" + }, + "output": { + "instillUIOrder": 0, + "properties": { + "file_id": { + "instillFormat": "string", + "instillUIOrder": 0, + "title": "File ID", + "description": "The ID to reference the file in queries", + "type": "string" + } + }, + "required": [ + "file_id" + ], + "title": "Output", + "type": "object" + } + } +} diff --git a/pkg/archetypeai/connector_test.go b/pkg/archetypeai/connector_test.go new file mode 100644 index 0000000..e8a2521 --- /dev/null +++ b/pkg/archetypeai/connector_test.go @@ -0,0 +1,316 @@ +package archetypeai + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + qt "github.com/frankban/quicktest" + "github.com/gofrs/uuid" + "github.com/instill-ai/component/pkg/base" + "github.com/instill-ai/connector/pkg/util/httpclient" + pb "github.com/instill-ai/protogen-go/vdp/pipeline/v1beta" + 
"github.com/instill-ai/x/errmsg" + "go.uber.org/zap" + "google.golang.org/protobuf/types/known/structpb" +) + +const ( + apiKey = "213bac" +) + +const errJSON = `{ "error": "Invalid access." }` +const describeJSON = ` +{ + "query_id": "2401242b4d59e48bbf6e0d", + "status": "completed", + "inference_time_sec": 1.6635565757751465, + "query_response_time_sec": 6.018876314163208, + "response": [ + { + "timestamp": 2.0, + "frame_id": 60, + "description": "The group of people is walking across a bridge." + }, + { + "timestamp": 6.0, + "frame_id": 180, + "description": "The man is walking across a bridge, and he is surrounded by people." + } + ] +}` +const describeErrJSON = ` +{ + "query_id": "2401242b4d59e48bbf6e0d", + "status": "failed", + "inference_time_sec": 1.6635565757751465, + "query_response_time_sec": 6.018876314163208, + "response": [ + { + "timestamp": 2.0, + "frame_id": 60, + "description": "The group of people is walking across a bridge." + } + ] +}` +const summarizeJSON = ` +{ + "query_id": "240123b93a83a79e9907a5", + "status": "completed", + "file_ids": [ + "test_image.jpg" + ], + "inference_time_sec": 2.1776912212371826, + "query_response_time_sec": 2.1914472579956055, + "response": { + "processed_text": "A family of four is hiking together on a trail." + } +}` +const summarizeErrJSON = ` +{ + "query_id": "2401233472bde249e60260", + "status": "failed", + "file_ids": [ + "test_image.jpg" + ] +}` +const uploadFileJSON = ` +{ + "is_valid": true, + "file_id": "2084fa42-8452-4fa6-bed9-6aac6d6153bb", + "file_uid": "2401242e3cb25122835a17" +}` +const uploadErrJSON = ` +{ + "is_valid": false, + "errors": [ + "Invalid file type: application/octet-stream. Supported file types are: ('image/jpeg', 'image/png', 'video/mp4')." 
+ ] +}` + +var ( + queryIn = fileQueryParams{ + Query: "Describe what's happening", + FileIDs: []string{"test.file"}, + } + uploadFileIn = uploadFileParams{ + File: "data:text/plain;base64,aG9sYQ==", + } +) + +func TestConnector_Execute(t *testing.T) { + c := qt.New(t) + + testcases := []struct { + name string + + task string + in any + want any + wantErr string + + // server expectations and response + wantPath string + wantReq any + wantContentType string + gotStatus int + gotResp string + }{ + { + name: "ok - describe", + + task: taskDescribe, + in: queryIn, + want: describeOutput{ + Descriptions: []frameDescription{ + { + Timestamp: 2.0, + FrameID: 60, + Description: "The group of people is walking across a bridge.", + }, + { + Timestamp: 6.0, + FrameID: 180, + Description: "The man is walking across a bridge, and he is surrounded by people.", + }, + }, + }, + + wantPath: describePath, + wantReq: queryIn, + wantContentType: httpclient.MIMETypeJSON, + gotStatus: http.StatusOK, + gotResp: describeJSON, + }, + { + name: "nok - describe error", + + task: taskDescribe, + in: queryIn, + wantErr: `Archetype AI didn't complete query 2401242b4d59e48bbf6e0d: status is "failed".`, + + wantPath: describePath, + wantReq: queryIn, + wantContentType: httpclient.MIMETypeJSON, + gotStatus: http.StatusOK, + gotResp: describeErrJSON, + }, + { + name: "ok - summarize", + + task: taskSummarize, + in: queryIn, + want: summarizeOutput{ + Response: "A family of four is hiking together on a trail.", + }, + + wantPath: summarizePath, + wantReq: queryIn, + wantContentType: httpclient.MIMETypeJSON, + gotStatus: http.StatusOK, + gotResp: summarizeJSON, + }, + { + name: "nok - summarize wrong file", + + task: taskSummarize, + in: queryIn, + wantErr: `Archetype AI didn't complete query 2401233472bde249e60260: status is "failed".`, + + wantPath: summarizePath, + wantReq: queryIn, + wantContentType: httpclient.MIMETypeJSON, + gotStatus: http.StatusOK, + gotResp: summarizeErrJSON, + }, + { + 
name: "ok - upload file", + + task: taskUploadFile, + in: uploadFileIn, + want: uploadFileOutput{FileID: "2084fa42-8452-4fa6-bed9-6aac6d6153bb"}, + + wantPath: uploadFilePath, + wantReq: "hola", + wantContentType: "multipart/form-data.*", + gotStatus: http.StatusOK, + gotResp: uploadFileJSON, + }, + { + name: "nok - upload invalid file", + + task: taskUploadFile, + in: uploadFileIn, + wantErr: "Couldn't complete upload: Invalid file type.*", + + wantPath: uploadFilePath, + wantReq: "hola", + wantContentType: "multipart/form-data.*", + gotStatus: http.StatusOK, + gotResp: uploadErrJSON, + }, + { + name: "nok - unauthorized", + + task: taskSummarize, + in: queryIn, + wantErr: "Archetype AI responded with a 401 status code. Invalid access.", + + wantPath: summarizePath, + wantReq: queryIn, + wantContentType: httpclient.MIMETypeJSON, + gotStatus: http.StatusUnauthorized, + gotResp: errJSON, + }, + } + + logger := zap.NewNop() + connector := Init(logger) + defID := uuid.Must(uuid.NewV4()) + + for _, tc := range testcases { + c.Run(tc.name, func(c *qt.C) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c.Check(r.Method, qt.Equals, http.MethodPost) + c.Check(r.URL.Path, qt.Matches, tc.wantPath) + + c.Check(r.Header.Get("Authorization"), qt.Equals, "Bearer "+apiKey) + c.Check(r.Header.Get("Content-Type"), qt.Matches, tc.wantContentType) + + body, err := io.ReadAll(r.Body) + c.Assert(err, qt.IsNil) + if tc.wantContentType == httpclient.MIMETypeJSON { + c.Check(body, qt.JSONEquals, tc.wantReq) + } else { + // We just do partial match to avoid matching every field + // in multipart bodies. 
+ c.Check(string(body), qt.Contains, tc.wantReq) + } + + w.Header().Set("Content-Type", httpclient.MIMETypeJSON) + w.WriteHeader(tc.gotStatus) + fmt.Fprintln(w, tc.gotResp) + }) + + srv := httptest.NewServer(h) + c.Cleanup(srv.Close) + + config, _ := structpb.NewStruct(map[string]any{ + "base_path": srv.URL, + "api_key": apiKey, + }) + + exec, err := connector.CreateExecution(defID, tc.task, config, logger) + c.Assert(err, qt.IsNil) + + pbIn, err := base.ConvertToStructpb(tc.in) + c.Assert(err, qt.IsNil) + + got, err := exec.Execute([]*structpb.Struct{pbIn}) + if tc.wantErr != "" { + c.Check(errmsg.Message(err), qt.Matches, tc.wantErr) + return + } + + c.Check(err, qt.IsNil) + c.Assert(got, qt.HasLen, 1) + + gotJSON, err := got[0].MarshalJSON() + c.Assert(err, qt.IsNil) + c.Check(gotJSON, qt.JSONEquals, tc.want) + }) + } +} + +func TestConnector_CreateExecution(t *testing.T) { + c := qt.New(t) + + logger := zap.NewNop() + connector := Init(logger) + defID := uuid.Must(uuid.NewV4()) + + c.Run("nok - unsupported task", func(c *qt.C) { + task := "FOOBAR" + want := fmt.Sprintf("%s task is not supported.", task) + + _, err := connector.CreateExecution(defID, task, new(structpb.Struct), logger) + c.Check(err, qt.IsNotNil) + c.Check(errmsg.Message(err), qt.Equals, want) + }) +} + +func TestConnector_Test(t *testing.T) { + c := qt.New(t) + + logger := zap.NewNop() + connector := Init(logger) + defID := uuid.Must(uuid.NewV4()) + + c.Run("ok - connected", func(c *qt.C) { + got, err := connector.Test(defID, nil, logger) + c.Check(err, qt.IsNil) + c.Check(got, qt.Equals, pb.Connector_STATE_CONNECTED) + }) +} diff --git a/pkg/archetypeai/main.go b/pkg/archetypeai/main.go new file mode 100644 index 0000000..d2560a5 --- /dev/null +++ b/pkg/archetypeai/main.go @@ -0,0 +1,245 @@ +package archetypeai + +import ( + "bytes" + _ "embed" + "fmt" + "strings" + "sync" + + "go.uber.org/zap" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/gofrs/uuid" + 
"github.com/instill-ai/component/pkg/base" + "github.com/instill-ai/connector/pkg/util" + "github.com/instill-ai/connector/pkg/util/httpclient" + "github.com/instill-ai/x/errmsg" + + pb "github.com/instill-ai/protogen-go/vdp/pipeline/v1beta" +) + +const ( + taskDescribe = "TASK_DESCRIBE" + taskSummarize = "TASK_SUMMARIZE" + taskUploadFile = "TASK_UPLOAD_FILE" +) + +var ( + //go:embed config/definitions.json + definitionsJSON []byte + //go:embed config/tasks.json + tasksJSON []byte + + once sync.Once + baseConn base.IConnector +) + +type connector struct { + base.Connector +} + +type execution struct { + base.Execution + execute func(*structpb.Struct) (*structpb.Struct, error) + client *httpclient.Client +} + +// Init returns an implementation of IConnector that interacts with Archetype +// AI. +func Init(logger *zap.Logger) base.IConnector { + once.Do(func() { + baseConn = &connector{ + Connector: base.Connector{ + Component: base.Component{Logger: logger}, + }, + } + if err := baseConn.LoadConnectorDefinitions(definitionsJSON, tasksJSON, nil); err != nil { + logger.Fatal(err.Error()) + } + }) + + return baseConn +} + +// CreateExecution returns an IExecution that executes tasks in Archetype AI. +func (c *connector) CreateExecution(defUID uuid.UUID, task string, config *structpb.Struct, logger *zap.Logger) (base.IExecution, error) { + e := &execution{ + client: newClient(config, logger), + } + + switch task { + case taskDescribe: + e.execute = e.describe + case taskSummarize: + e.execute = e.summarize + case taskUploadFile: + e.execute = e.uploadFile + default: + return nil, errmsg.AddMessage( + fmt.Errorf("not supported task: %s", task), + fmt.Sprintf("%s task is not supported.", task), + ) + } + + e.Execution = base.CreateExecutionHelper(e, c, defUID, task, config, logger) + + return e, nil +} + +// Execute performs calls the Archetype AI API to execute a task. 
+func (e *execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) { + outputs := make([]*structpb.Struct, len(inputs)) + + for i, input := range inputs { + output, err := e.execute(input) + if err != nil { + return nil, err + } + + outputs[i] = output + } + + return outputs, nil +} + +func (e *execution) describe(in *structpb.Struct) (*structpb.Struct, error) { + params := fileQueryParams{} + if err := base.ConvertFromStructpb(in, &params); err != nil { + return nil, err + } + + // We have a 1-1 mapping between the VDP user input and the Archetype AI + // request. If this stops being the case in the future, we'll need a + // describeReq structure. + resp := describeResp{} + req := e.client.R().SetBody(params).SetResult(&resp) + + if _, err := req.Post(describePath); err != nil { + return nil, err + } + + // Archetype AI might return a 200 status even if the operation failed + // (e.g. if the file doesn't exist). + if resp.Status != statusCompleted { + return nil, errmsg.AddMessage( + fmt.Errorf("response with non-completed status"), + fmt.Sprintf(`Archetype AI didn't complete query %s: status is "%s".`, resp.QueryID, resp.Status), + ) + } + + out, err := base.ConvertToStructpb(describeOutput{ + Descriptions: resp.Response, + }) + if err != nil { + return nil, err + } + + return out, nil +} + +func (e *execution) summarize(in *structpb.Struct) (*structpb.Struct, error) { + params := fileQueryParams{} + if err := base.ConvertFromStructpb(in, &params); err != nil { + return nil, err + } + + // We have a 1-1 mapping between the VDP user input and the Archetype AI + // request. If this stops being the case in the future, we'll need a + // summarizeReq structure. + resp := summarizeResp{} + req := e.client.R().SetBody(params).SetResult(&resp) + + if _, err := req.Post(summarizePath); err != nil { + return nil, err + } + + // Archetype AI might return a 200 status even if the operation failed + // (e.g. if the file doesn't exist). 
+ if resp.Status != statusCompleted { + return nil, errmsg.AddMessage( + fmt.Errorf("response with non-completed status"), + fmt.Sprintf(`Archetype AI didn't complete query %s: status is "%s".`, resp.QueryID, resp.Status), + ) + } + + out, err := base.ConvertToStructpb(summarizeOutput{ + Response: resp.Response.ProcessedText, + }) + if err != nil { + return nil, err + } + + return out, nil +} + +func (e *execution) uploadFile(in *structpb.Struct) (*structpb.Struct, error) { + params := uploadFileParams{} + if err := base.ConvertFromStructpb(in, &params); err != nil { + return nil, err + } + + resp := uploadFileResp{} + req := e.client.R().SetResult(&resp) + + b, err := util.DecodeBase64(params.File) + if err != nil { + return nil, err + } + + id, err := uuid.NewV4() + if err != nil { + return nil, err + } + + req.SetFileReader("file", id.String(), bytes.NewReader(b)) + if _, err := req.Post(uploadFilePath); err != nil { + return nil, err + } + + if !resp.IsValid { + errMsg := "invalid file." + if len(resp.Errors) > 0 { + errMsg = strings.Join(resp.Errors, " ") + } + + return nil, errmsg.AddMessage( + fmt.Errorf("file upload failed"), + fmt.Sprintf(`Couldn't complete upload: %s`, errMsg), + ) + } + + out, err := base.ConvertToStructpb(resp.uploadFileOutput) + if err != nil { + return nil, err + } + + return out, nil +} + +// Test checks the connectivity of the connector. +func (c *connector) Test(_ uuid.UUID, _ *structpb.Struct, _ *zap.Logger) (pb.Connector_State, error) { + // TODO Archetype AI API is not public yet. We could test the connection + // by calling one of the endpoints used in the available tasks. However, + // these are not designed specifically for this purpose. When we know + // of an endpoint that's more suited for this, it should be used instead. 
+ return pb.Connector_STATE_CONNECTED, nil +} + +func getAPIKey(config *structpb.Struct) string { + return config.GetFields()["api_key"].GetStringValue() +} + +// getBasePath returns Archetype AI's API URL. This configuration param allows +// us to override the API the connector will point to. It isn't meant to be +// exposed to users. Rather, it can serve to test the logic against a fake +// server. +// TODO instead of having the API value hardcoded in the codebase, it should +// be read from a config file or environment variable. +func getBasePath(config *structpb.Struct) string { + v, ok := config.GetFields()["base_path"] + if !ok { + return host + } + return v.GetStringValue() +} diff --git a/pkg/archetypeai/structs.go b/pkg/archetypeai/structs.go new file mode 100644 index 0000000..4718c2e --- /dev/null +++ b/pkg/archetypeai/structs.go @@ -0,0 +1,64 @@ +package archetypeai + +// fileQueryParams holds a query about a file. It is used as the input in +// e.g. video description or image summarization tasks. +type fileQueryParams struct { + Query string `json:"query"` + FileIDs []string `json:"file_ids"` +} + +// summarizeOutput is used to return the output of a TASK_SUMMARIZE execution. +type summarizeOutput struct { + Response string `json:"response"` +} + +const ( + statusCompleted = "completed" + statusFailed = "failed" +) + +// summarizeResp holds the response from the Archetype AI API call. +type summarizeResp struct { + QueryID string `json:"query_id"` + Status string `json:"status"` + Response struct { + ProcessedText string `json:"processed_text"` + } `json:"response"` +} + +type frameDescription struct { + Timestamp float32 `json:"timestamp"` + FrameID uint64 `json:"frame_id"` + Description string `json:"description"` +} + +// describeResp holds the response from the Archetype AI API call. 
+type describeResp struct { + QueryID string `json:"query_id"` + Status string `json:"status"` + Response []frameDescription `json:"response"` +} + +// describeOutput is used to return the output of a TASK_DESCRIBE execution. +type describeOutput struct { + Descriptions []frameDescription `json:"descriptions"` +} + +// uploadFileParams holds the input of a file upload task. +type uploadFileParams struct { + File string `json:"file"` +} + +// uploadFileOutput is used to return the output of a TASK_UPLOAD_FILE +// execution. +type uploadFileOutput struct { + FileID string `json:"file_id"` +} + +// uploadFileResp holds the response from the Archetype AI API call. +type uploadFileResp struct { + uploadFileOutput + + IsValid bool `json:"is_valid"` + Errors []string `json:"errors"` +} diff --git a/pkg/integration_test.go b/pkg/integration_test.go index fda8e26..b168c4c 100644 --- a/pkg/integration_test.go +++ b/pkg/integration_test.go @@ -42,7 +42,7 @@ func TestOpenAITextGeneration(t *testing.T) { logger := zap.NewNop() conn := Init(logger, emptyOptions) - def, err := conn.GetConnectorDefinitionByID("openai") + def, err := conn.GetConnectorDefinitionByID("openai", nil, nil) c.Assert(err, qt.IsNil) uid, err := uuid.FromString(def.GetUid()) diff --git a/pkg/main.go b/pkg/main.go index 20876c2..d7ccf8a 100644 --- a/pkg/main.go +++ b/pkg/main.go @@ -9,6 +9,7 @@ import ( "github.com/instill-ai/component/pkg/base" "github.com/instill-ai/connector/pkg/airbyte" + "github.com/instill-ai/connector/pkg/archetypeai" "github.com/instill-ai/connector/pkg/bigquery" "github.com/instill-ai/connector/pkg/googlecloudstorage" "github.com/instill-ai/connector/pkg/googlesearch" @@ -51,6 +52,7 @@ func Init(logger *zap.Logger, options ConnectorOptions) base.IConnector { connector.(*Connector).ImportDefinitions(instill.Init(logger)) connector.(*Connector).ImportDefinitions(huggingface.Init(logger)) connector.(*Connector).ImportDefinitions(openai.Init(logger)) + 
connector.(*Connector).ImportDefinitions(archetypeai.Init(logger)) connector.(*Connector).ImportDefinitions(numbers.Init(logger)) connector.(*Connector).ImportDefinitions(airbyte.Init(logger, options.Airbyte)) connector.(*Connector).ImportDefinitions(bigquery.Init(logger)) diff --git a/pkg/stabilityai/image_to_image.go b/pkg/stabilityai/image_to_image.go index c791baa..df47cfc 100644 --- a/pkg/stabilityai/image_to_image.go +++ b/pkg/stabilityai/image_to_image.go @@ -112,7 +112,7 @@ func parseImageToImageReq(from *structpb.Struct) (ImageToImageReq, error) { func (req ImageToImageReq) getBytes() (b *bytes.Reader, contentType string, err error) { data := &bytes.Buffer{} - initImage, err := DecodeBase64(req.InitImage) + initImage, err := util.DecodeBase64(req.InitImage) if err != nil { return nil, "", err } diff --git a/pkg/stabilityai/main.go b/pkg/stabilityai/main.go index 7518634..60afec0 100644 --- a/pkg/stabilityai/main.go +++ b/pkg/stabilityai/main.go @@ -2,7 +2,6 @@ package stabilityai import ( _ "embed" - "encoding/base64" "fmt" "sync" @@ -155,8 +154,3 @@ func (c *Connector) Test(_ uuid.UUID, config *structpb.Struct, logger *zap.Logge return pipelinePB.Connector_STATE_CONNECTED, nil } - -// decode if the string is base64 encoded -func DecodeBase64(input string) ([]byte, error) { - return base64.StdEncoding.DecodeString(base.TrimBase64Mime(input)) -} diff --git a/pkg/util/helper.go b/pkg/util/helper.go index 3400a1a..a7c79ee 100644 --- a/pkg/util/helper.go +++ b/pkg/util/helper.go @@ -1,6 +1,7 @@ package util import ( + "encoding/base64" "mime/multipart" "net/http" "strings" @@ -8,6 +9,7 @@ import ( md "github.com/JohannesKaufmann/html-to-markdown" "github.com/PuerkitoBio/goquery" "github.com/h2non/filetype" + "github.com/instill-ai/component/pkg/base" ) func GetFileExt(fileData []byte) string { @@ -62,3 +64,9 @@ func ScrapeWebpageHTMLToMarkdown(html string) (string, error) { return markdown, nil } + +// DecodeBase64 takes a base64-encoded blob, trims the MIME 
type (if present) +// and decodes the remaining bytes. +func DecodeBase64(input string) ([]byte, error) { + return base64.StdEncoding.DecodeString(base.TrimBase64Mime(input)) +} diff --git a/pkg/util/helper_test.go b/pkg/util/helper_test.go new file mode 100644 index 0000000..4194451 --- /dev/null +++ b/pkg/util/helper_test.go @@ -0,0 +1,31 @@ +package util + +import ( + "testing" + + qt "github.com/frankban/quicktest" +) + +func TestDecodeBase64(t *testing.T) { + c := qt.New(t) + + c.Run("ok - with MIME prepended", func(c *qt.C) { + in := "data:text/plain;base64,aG9sYQ==" + got, err := DecodeBase64(in) + c.Check(err, qt.IsNil) + c.Check(got, qt.ContentEquals, []byte("hola")) + }) + + c.Run("ok - without MIME", func(c *qt.C) { + in := "aG9sYQ==" + got, err := DecodeBase64(in) + c.Check(err, qt.IsNil) + c.Check(got, qt.ContentEquals, []byte("hola")) + }) + + c.Run("nok - invalid", func(c *qt.C) { + in := "hola==" + _, err := DecodeBase64(in) + c.Check(err, qt.IsNotNil) + }) +}