From 5f216cecb1cf62040272d0ec8341e772c688071a Mon Sep 17 00:00:00 2001 From: Sebastien Flory Date: Fri, 31 Jan 2025 19:31:29 +0100 Subject: [PATCH] Add: location support when connecting bigquery --- connectors/src/connectors/bigquery/index.ts | 15 +- .../connectors/bigquery/lib/bigquery_api.ts | 29 ++- .../connectors/bigquery/lib/permissions.ts | 6 +- .../bigquery/temporal/activities.ts | 4 +- .../bigquery/temporal/cast_known_errors.ts | 83 +------- core/src/oauth/credential.rs | 1 + .../src/oauth/tests/functional_credentials.rs | 7 +- .../CreateOrUpdateConnectionBigQueryModal.tsx | 91 +++++++- front/lib/api/oauth.ts | 33 --- front/lib/swr/bigquery.ts | 42 ++++ front/package-lock.json | 157 +++++++++++++- front/package.json | 1 + .../check_bigquery_locations.test.ts | 197 ++++++++++++++++++ .../credentials/check_bigquery_locations.ts | 134 ++++++++++++ front/pages/api/w/[wId]/credentials/index.ts | 16 +- types/src/oauth/lib.ts | 37 +++- types/src/oauth/oauth_api.ts | 1 + 17 files changed, 698 insertions(+), 156 deletions(-) create mode 100644 front/lib/swr/bigquery.ts create mode 100644 front/pages/api/w/[wId]/credentials/check_bigquery_locations.test.ts create mode 100644 front/pages/api/w/[wId]/credentials/check_bigquery_locations.ts diff --git a/connectors/src/connectors/bigquery/index.ts b/connectors/src/connectors/bigquery/index.ts index c5d35c349265..e8ef5331e645 100644 --- a/connectors/src/connectors/bigquery/index.ts +++ b/connectors/src/connectors/bigquery/index.ts @@ -4,7 +4,12 @@ import type { ContentNodesViewType, Result, } from "@dust-tt/types"; -import { assertNever, Err, isBigQueryCredentials, Ok } from "@dust-tt/types"; +import { + assertNever, + Err, + isBigQueryWithLocationCredentials, + Ok, +} from "@dust-tt/types"; import type { TestConnectionError } from "@connectors/connectors/bigquery/lib/bigquery_api"; import { testConnection } from "@connectors/connectors/bigquery/lib/bigquery_api"; @@ -68,7 +73,7 @@ export class BigQueryConnectorManager 
extends BaseConnectorManager { }): Promise>> { const credentialsRes = await getCredentials({ credentialsId: connectionId, - isTypeGuard: isBigQueryCredentials, + isTypeGuard: isBigQueryWithLocationCredentials, logger, }); if (credentialsRes.isErr()) { @@ -125,7 +130,7 @@ export class BigQueryConnectorManager extends BaseConnectorManager { const newCredentialsRes = await getCredentials({ credentialsId: connectionId, - isTypeGuard: isBigQueryCredentials, + isTypeGuard: isBigQueryWithLocationCredentials, logger, }); if (newCredentialsRes.isErr()) { @@ -250,7 +255,7 @@ export class BigQueryConnectorManager extends BaseConnectorManager { > { const connectorAndCredentialsRes = await getConnectorAndCredentials({ connectorId: this.connectorId, - isTypeGuard: isBigQueryCredentials, + isTypeGuard: isBigQueryWithLocationCredentials, logger, }); if (connectorAndCredentialsRes.isErr()) { @@ -319,7 +324,7 @@ export class BigQueryConnectorManager extends BaseConnectorManager { }): Promise> { const connectorAndCredentialsRes = await getConnectorAndCredentials({ connectorId: this.connectorId, - isTypeGuard: isBigQueryCredentials, + isTypeGuard: isBigQueryWithLocationCredentials, logger, }); if (connectorAndCredentialsRes.isErr()) { diff --git a/connectors/src/connectors/bigquery/lib/bigquery_api.ts b/connectors/src/connectors/bigquery/lib/bigquery_api.ts index 22208cdb3eb2..a4434629581d 100644 --- a/connectors/src/connectors/bigquery/lib/bigquery_api.ts +++ b/connectors/src/connectors/bigquery/lib/bigquery_api.ts @@ -1,4 +1,4 @@ -import type { BigQueryCredentials, Result } from "@dust-tt/types"; +import type { BigQueryCredentialsWithLocation, Result } from "@dust-tt/types"; import { Err, Ok, removeNulls } from "@dust-tt/types"; import { BigQuery } from "@google-cloud/bigquery"; @@ -34,7 +34,7 @@ export function isTestConnectionError( export const testConnection = async ({ credentials, }: { - credentials: BigQueryCredentials; + credentials: BigQueryCredentialsWithLocation; }): 
Promise> => { // Connect to bigquery, do a simple query. const bigQuery = connectToBigQuery(credentials); @@ -53,17 +53,20 @@ export const testConnection = async ({ } }; -export function connectToBigQuery(credentials: BigQueryCredentials): BigQuery { +export function connectToBigQuery( + credentials: BigQueryCredentialsWithLocation +): BigQuery { return new BigQuery({ credentials, scopes: ["https://www.googleapis.com/auth/bigquery.readonly"], + location: credentials.location, }); } export const fetchDatabases = ({ credentials, }: { - credentials: BigQueryCredentials; + credentials: BigQueryCredentialsWithLocation; }): RemoteDBDatabase[] => { // BigQuery do not have a concept of databases per say, the most similar concept is a project. // Since credentials are always scoped to a project, we directly return a single database with the project name. @@ -79,7 +82,7 @@ export const fetchDatasets = async ({ credentials, connection, }: { - credentials: BigQueryCredentials; + credentials: BigQueryCredentialsWithLocation; connection?: BigQuery; }): Promise, Error>> => { const conn = connection ?? connectToBigQuery(credentials); @@ -89,6 +92,14 @@ export const fetchDatasets = async ({ return new Ok( removeNulls( datasets.map((dataset) => { + // We want to filter out datasets that are not in the same location as the credentials. 
+ // But, for example, we want to keep dataset in "us-central1" when selected location is "us" + if ( + !dataset.location?.toLowerCase().startsWith(credentials.location) + ) { + return null; + } + if (!dataset.id) { return null; } @@ -113,7 +124,7 @@ export const fetchTables = async ({ internalDatasetId, connection, }: { - credentials: BigQueryCredentials; + credentials: BigQueryCredentialsWithLocation; datasetName?: string; internalDatasetId?: string; connection?: BigQuery; @@ -147,6 +158,12 @@ export const fetchTables = async ({ return new Ok( removeNulls( tables.map((table) => { + // We want to filter out tables that are not in the same location as the credentials. + // But, for example, we want to keep tables in "us-central1" when selected location is "us" + if (!table.location?.toLowerCase().startsWith(credentials.location)) { + return null; + } + if (!table.id) { return null; } diff --git a/connectors/src/connectors/bigquery/lib/permissions.ts b/connectors/src/connectors/bigquery/lib/permissions.ts index d4e237261094..c39adb07bce0 100644 --- a/connectors/src/connectors/bigquery/lib/permissions.ts +++ b/connectors/src/connectors/bigquery/lib/permissions.ts @@ -1,5 +1,5 @@ import type { - BigQueryCredentials, + BigQueryCredentialsWithLocation, ContentNode, ModelId, Result, @@ -34,7 +34,7 @@ export const fetchAvailableChildrenInBigQuery = async ({ parentInternalId, }: { connectorId: ModelId; - credentials: BigQueryCredentials; + credentials: BigQueryCredentialsWithLocation; parentInternalId: string | null; }): Promise> => { if (parentInternalId === null) { @@ -283,7 +283,7 @@ export const saveNodesFromPermissions = async ({ }: { permissions: Record; connectorId: ModelId; - credentials: BigQueryCredentials; + credentials: BigQueryCredentialsWithLocation; logger: Logger; }): Promise> => { for (const [internalId, permission] of Object.entries(permissions)) { diff --git a/connectors/src/connectors/bigquery/temporal/activities.ts 
b/connectors/src/connectors/bigquery/temporal/activities.ts index 0e77543e16f5..b99c2170ec79 100644 --- a/connectors/src/connectors/bigquery/temporal/activities.ts +++ b/connectors/src/connectors/bigquery/temporal/activities.ts @@ -1,5 +1,5 @@ import type { ModelId } from "@dust-tt/types"; -import { isBigQueryCredentials, MIME_TYPES } from "@dust-tt/types"; +import { isBigQueryWithLocationCredentials, MIME_TYPES } from "@dust-tt/types"; import { connectToBigQuery, @@ -30,7 +30,7 @@ import logger from "@connectors/logger/logger"; export async function syncBigQueryConnection(connectorId: ModelId) { const getConnectorAndCredentialsRes = await getConnectorAndCredentials({ connectorId, - isTypeGuard: isBigQueryCredentials, + isTypeGuard: isBigQueryWithLocationCredentials, logger, }); if (getConnectorAndCredentialsRes.isErr()) { diff --git a/connectors/src/connectors/bigquery/temporal/cast_known_errors.ts b/connectors/src/connectors/bigquery/temporal/cast_known_errors.ts index a57d7b18796e..add87f0fa3c0 100644 --- a/connectors/src/connectors/bigquery/temporal/cast_known_errors.ts +++ b/connectors/src/connectors/bigquery/temporal/cast_known_errors.ts @@ -4,73 +4,6 @@ import type { Next, } from "@temporalio/worker"; -import { ExternalOAuthTokenError } from "@connectors/lib/error"; - -interface BigQueryError extends Error { - name: string; - data: { - nextAction: string; - }; -} - -interface BigQueryExpiredPasswordError extends BigQueryError { - name: "OperationFailedError"; - data: { - nextAction: "PWD_CHANGE"; - }; -} - -interface BigQueryAccountLockedError extends BigQueryError { - name: "OperationFailedError"; - data: { - nextAction: "RETRY_LOGIN"; - }; -} - -interface BigQueryIncorrectCredentialsError extends BigQueryError { - name: "OperationFailedError"; - data: { - nextAction: "RETRY_LOGIN"; - }; -} - -function isBigQueryError(err: unknown): err is BigQueryError { - return ( - err instanceof Error && - "name" in err && - "data" in err && - typeof err.data === 
"object" && - err.data !== null && - "nextAction" in err.data && - typeof err.data.nextAction === "string" - ); -} - -function isBigQueryExpiredPasswordError( - err: unknown -): err is BigQueryExpiredPasswordError { - return isBigQueryError(err) && err.data.nextAction === "PWD_CHANGE"; -} - -function isBigQueryAccountLockedError( - err: unknown -): err is BigQueryAccountLockedError { - return ( - isBigQueryError(err) && - err.message.startsWith( - "Your user account has been temporarily locked due to too many failed attempts" - ) - ); -} - -function isBigQueryIncorrectCredentialsError( - err: unknown -): err is BigQueryIncorrectCredentialsError { - return ( - isBigQueryError(err) && - err.message.startsWith("Incorrect username or password was specified") - ); -} export class BigQueryCastKnownErrorsInterceptor implements ActivityInboundCallsInterceptor { @@ -78,19 +11,7 @@ export class BigQueryCastKnownErrorsInterceptor input: ActivityExecuteInput, next: Next ): Promise { - try { - return await next(input); - } catch (err: unknown) { - if ( - isBigQueryExpiredPasswordError(err) || - // technically, the one below could be transient; - // we add it here to make the user aware that getting locked out of his account blocks the connection - isBigQueryAccountLockedError(err) || - isBigQueryIncorrectCredentialsError(err) - ) { - throw new ExternalOAuthTokenError(err); - } - throw err; - } + // Will add custom error handling as we discover them + return next(input); } } diff --git a/core/src/oauth/credential.rs b/core/src/oauth/credential.rs index b646fb2b5403..233e8e74078e 100644 --- a/core/src/oauth/credential.rs +++ b/core/src/oauth/credential.rs @@ -125,6 +125,7 @@ impl Credential { "auth_provider_x509_cert_url", "client_x509_cert_url", "universe_domain", + "location", ] } }; diff --git a/core/src/oauth/tests/functional_credentials.rs b/core/src/oauth/tests/functional_credentials.rs index 5b42a90d62c1..435c7db3c128 100644 --- 
a/core/src/oauth/tests/functional_credentials.rs +++ b/core/src/oauth/tests/functional_credentials.rs @@ -151,7 +151,8 @@ async fn test_oauth_credentials_bigquery_flow_ok() { "token_uri": "https://oauth2.googleapis.com/token", "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test", - "universe_domain": "googleapis.com" + "universe_domain": "googleapis.com", + "location": "EU" } }); @@ -209,6 +210,7 @@ async fn test_oauth_credentials_bigquery_flow_ok() { content.get("client_email").unwrap(), "test@test-project.iam.gserviceaccount.com" ); + assert_eq!(content.get("location").unwrap(), "EU"); } #[tokio::test] @@ -231,7 +233,8 @@ async fn test_oauth_credentials_bigquery_delete_ok() { "token_uri": "https://oauth2.googleapis.com/token", "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test", - "universe_domain": "googleapis.com" + "universe_domain": "googleapis.com", + "location": "US" } }); diff --git a/front/components/data_source/CreateOrUpdateConnectionBigQueryModal.tsx b/front/components/data_source/CreateOrUpdateConnectionBigQueryModal.tsx index f87029fd35d6..366d1b669ff9 100644 --- a/front/components/data_source/CreateOrUpdateConnectionBigQueryModal.tsx +++ b/front/components/data_source/CreateOrUpdateConnectionBigQueryModal.tsx @@ -2,19 +2,24 @@ import { BookOpenIcon, Button, CloudArrowLeftRightIcon, + InformationCircleIcon, Modal, Page, + RadioGroup, + RadioGroupChoice, TextArea, + Tooltip, } from "@dust-tt/sparkle"; import type { - BigQueryCredentials, + BigQueryCredentialsWithLocation, + CheckBigQueryCredentials, ConnectorProvider, ConnectorType, DataSourceType, WorkspaceType, } from "@dust-tt/types"; import { - BigQueryCredentialsSchema, + CheckBigQueryCredentialsSchema, isConnectorsAPIError, } from "@dust-tt/types"; import { isRight } from 
"fp-ts/lib/Either"; @@ -22,6 +27,7 @@ import { formatValidationErrors } from "io-ts-reporters"; import { useEffect, useMemo, useState } from "react"; import type { ConnectorProviderConfiguration } from "@app/lib/connector_providers"; +import { useBigQueryLocations } from "@app/lib/swr/bigquery"; import type { PostCredentialsBody } from "@app/pages/api/w/[wId]/credentials"; type CreateOrUpdateConnectionBigQueryModalProps = { @@ -63,8 +69,9 @@ export function CreateOrUpdateConnectionBigQueryModal({ } try { - const credentialsObject: BigQueryCredentials = JSON.parse(credentials); - const r = BigQueryCredentialsSchema.decode(credentialsObject); + const credentialsObject: CheckBigQueryCredentials = + JSON.parse(credentials); + const r = CheckBigQueryCredentialsSchema.decode(credentialsObject); if (isRight(r)) { const allFieldsHaveValue = Object.values(credentialsObject).every( (v) => v.length > 0 @@ -92,6 +99,23 @@ export function CreateOrUpdateConnectionBigQueryModal({ } }, [credentials]); + // Region picking + const [selectedLocation, setSelectedLocation] = useState(); + const { locations, isLocationsLoading } = useBigQueryLocations({ + owner, + credentials: credentialsState.credentials, + }); + + const needToSelectLocation = useMemo(() => { + return locations && Object.keys(locations).length > 1; + }, [locations]); + + useEffect(() => { + if (locations && Object.keys(locations).length === 1) { + setSelectedLocation(Object.keys(locations)[0]); + } + }, [locations]); + useEffect(() => { setError(credentialsState.errorMessage); }, [credentialsState.errorMessage]); @@ -124,7 +148,10 @@ export function CreateOrUpdateConnectionBigQueryModal({ }, body: JSON.stringify({ provider: "bigquery" as const, - credentials: JSON.parse(credentials) as BigQueryCredentials, + credentials: { + ...credentialsState.credentials, + location: selectedLocation, + } as BigQueryCredentialsWithLocation, } as PostCredentialsBody), } ); @@ -192,7 +219,10 @@ export function 
CreateOrUpdateConnectionBigQueryModal({ }, body: JSON.stringify({ provider: "bigquery", - credentials: credentialsState.credentials, + credentials: { + ...credentialsState.credentials, + location: selectedLocation, + } as BigQueryCredentialsWithLocation, }), }); @@ -297,6 +327,48 @@ export function CreateOrUpdateConnectionBigQueryModal({ /> + {needToSelectLocation && ( +
+
+ Select a location +
+ +
+ {Object.entries(locations).map(([location, tables]) => ( + + + This location contains {tables.length} tables + that can be connected :{" "} + + {tables.join(", ")} + + + } + trigger={ +
+ {location} - {tables.length} tables{" "} + +
+ } + /> + + } + /> + ))} +
+
+
+ )} +