diff --git a/lakeview/build.gradle b/lakeview/build.gradle index fe50b6ed..b9a914fe 100644 --- a/lakeview/build.gradle +++ b/lakeview/build.gradle @@ -61,6 +61,11 @@ dependencies { exclude group: "com.google.protobuf", module: "protobuf-java-utils" } + implementation('com.azure:azure-storage-file-datalake:12.19.1') { + exclude group: 'com.azure', module: 'azure-core-http-netty' + } + implementation 'com.azure:azure-identity:1.10.0' + implementation 'org.springframework:spring-beans:5.3.23' testImplementation "org.mockito:mockito-core:${versions.mockito}" diff --git a/lakeview/src/main/java/ai/onehouse/RuntimeModule.java b/lakeview/src/main/java/ai/onehouse/RuntimeModule.java index f5b5bd84..960232b8 100644 --- a/lakeview/src/main/java/ai/onehouse/RuntimeModule.java +++ b/lakeview/src/main/java/ai/onehouse/RuntimeModule.java @@ -10,9 +10,11 @@ import ai.onehouse.config.ConfigProvider; import ai.onehouse.config.models.common.FileSystemConfiguration; import ai.onehouse.storage.AsyncStorageClient; +import ai.onehouse.storage.AzureAsyncStorageClient; import ai.onehouse.storage.GCSAsyncStorageClient; import ai.onehouse.storage.S3AsyncStorageClient; import ai.onehouse.storage.StorageUtils; +import ai.onehouse.storage.providers.AzureStorageClientProvider; import ai.onehouse.storage.providers.GcsClientProvider; import ai.onehouse.storage.providers.S3AsyncClientProvider; @@ -128,10 +130,13 @@ static AsyncStorageClient providesAsyncStorageClientForDiscovery( StorageUtils storageUtils, @TableDiscoveryS3ObjectStorageClient S3AsyncClientProvider s3AsyncClientProvider, GcsClientProvider gcsClientProvider, + AzureStorageClientProvider azureStorageClientProvider, ExecutorService executorService) { FileSystemConfiguration fileSystemConfiguration = config.getFileSystemConfiguration(); if (fileSystemConfiguration.getS3Config() != null) { return new S3AsyncStorageClient(s3AsyncClientProvider, storageUtils, executorService); + } else if (fileSystemConfiguration.getAzureConfig() != null) { + return new AzureAsyncStorageClient(azureStorageClientProvider, storageUtils, executorService); } else { return new GCSAsyncStorageClient(gcsClientProvider, storageUtils, executorService); } @@ -145,10 +150,13 @@ static AsyncStorageClient providesAsyncStorageClientForUpload( StorageUtils storageUtils, @TableMetadataUploadS3ObjectStorageClient S3AsyncClientProvider s3AsyncClientProvider, GcsClientProvider gcsClientProvider, + AzureStorageClientProvider azureStorageClientProvider, ExecutorService executorService) { FileSystemConfiguration fileSystemConfiguration = config.getFileSystemConfiguration(); if (fileSystemConfiguration.getS3Config() != null) { return new S3AsyncStorageClient(s3AsyncClientProvider, storageUtils, executorService); + } else if (fileSystemConfiguration.getAzureConfig() != null) { + return new AzureAsyncStorageClient(azureStorageClientProvider, storageUtils, executorService); } else { return new GCSAsyncStorageClient(gcsClientProvider, storageUtils, executorService); } diff --git a/lakeview/src/main/java/ai/onehouse/config/models/common/AzureConfig.java b/lakeview/src/main/java/ai/onehouse/config/models/common/AzureConfig.java new file mode 100644 index 00000000..4a5e4114 --- /dev/null +++ b/lakeview/src/main/java/ai/onehouse/config/models/common/AzureConfig.java @@ -0,0 +1,28 @@ +package ai.onehouse.config.models.common; + +import java.util.Optional; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NonNull; +import lombok.extern.jackson.Jacksonized; + +@Builder +@Getter +@Jacksonized +@EqualsAndHashCode +public class AzureConfig { + @NonNull private String accountName; + + // Optional authentication methods + // Option 1: Account Key (for dev/testing, never expires) + @Builder.Default private Optional accountKey = Optional.empty(); + + // Option 2: Connection String (alternative to account key) + @Builder.Default private Optional connectionString = Optional.empty(); + + // Option 3: Service Principal + @Builder.Default private Optional tenantId = Optional.empty(); + @Builder.Default private Optional clientId = Optional.empty(); + @Builder.Default private Optional clientSecret = Optional.empty(); +} diff --git a/lakeview/src/main/java/ai/onehouse/config/models/common/FileSystemConfiguration.java b/lakeview/src/main/java/ai/onehouse/config/models/common/FileSystemConfiguration.java index b681385c..26799b48 100644 --- a/lakeview/src/main/java/ai/onehouse/config/models/common/FileSystemConfiguration.java +++ b/lakeview/src/main/java/ai/onehouse/config/models/common/FileSystemConfiguration.java @@ -12,4 +12,5 @@ public class FileSystemConfiguration { private S3Config s3Config; private GCSConfig gcsConfig; + private AzureConfig azureConfig; } diff --git a/lakeview/src/main/java/ai/onehouse/constants/StorageConstants.java b/lakeview/src/main/java/ai/onehouse/constants/StorageConstants.java index e703754f..33f373fe 100644 --- a/lakeview/src/main/java/ai/onehouse/constants/StorageConstants.java +++ b/lakeview/src/main/java/ai/onehouse/constants/StorageConstants.java @@ -5,10 +5,14 @@ public class StorageConstants { private StorageConstants() {} - // typical s3 path: "s3://bucket-name/path/to/object" - // gcs path format "gs:// [bucket] /path/to/file" + /* + * typical s3 path: "s3://bucket-name/path/to/object" + * gcs path format: "gs://bucket/path/to/file" + * azure blob format: "https://account.blob.core.windows.net/container/path/to/blob" + * azure adls gen2 format: "https://account.dfs.core.windows.net/container/path/to/file" + */ public static final Pattern OBJECT_STORAGE_URI_PATTERN = - Pattern.compile("^(s3://|gs://)([^/]+)(/.*)?"); + Pattern.compile("^(?:(s3://|gs://)|https://[^.]+\\.(?:blob|dfs)\\.core\\.windows\\.net/)([^/]+)(/.*)?$"); // https://cloud.google.com/compute/docs/naming-resources#resource-name-format public static final String GCP_RESOURCE_NAME_FORMAT = "^[a-z]([-a-z0-9]*[a-z0-9])$"; diff --git a/lakeview/src/main/java/ai/onehouse/storage/AzureAsyncStorageClient.java b/lakeview/src/main/java/ai/onehouse/storage/AzureAsyncStorageClient.java new file mode 100644 index 00000000..e56020e0 --- /dev/null +++ b/lakeview/src/main/java/ai/onehouse/storage/AzureAsyncStorageClient.java @@ -0,0 +1,233 @@ +package ai.onehouse.storage; + +import ai.onehouse.exceptions.AccessDeniedException; +import ai.onehouse.exceptions.NoSuchKeyException; +import ai.onehouse.exceptions.ObjectStorageClientException; +import ai.onehouse.exceptions.RateLimitException; +import ai.onehouse.storage.models.File; +import ai.onehouse.storage.models.FileStreamData; +import ai.onehouse.storage.providers.AzureStorageClientProvider; +import com.azure.core.http.rest.PagedFlux; +import com.azure.core.http.rest.PagedResponse; +import com.azure.core.util.BinaryData; +import com.azure.storage.file.datalake.DataLakeDirectoryAsyncClient; +import com.azure.storage.file.datalake.DataLakeFileAsyncClient; +import com.azure.storage.file.datalake.DataLakeFileSystemAsyncClient; +import com.azure.storage.file.datalake.DataLakeServiceAsyncClient; +import com.azure.storage.file.datalake.models.DataLakeRequestConditions; +import com.azure.storage.file.datalake.models.DataLakeStorageException; +import com.azure.storage.file.datalake.models.ListPathsOptions; +import com.azure.storage.file.datalake.models.PathItem; +import com.google.common.annotations.VisibleForTesting; +import com.google.inject.Inject; +import java.io.ByteArrayInputStream; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import javax.annotation.Nonnull; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; + +@Slf4j +public class AzureAsyncStorageClient extends AbstractAsyncStorageClient { + private final AzureStorageClientProvider azureStorageClientProvider; + + @Inject + public AzureAsyncStorageClient( + @Nonnull AzureStorageClientProvider azureStorageClientProvider, + @Nonnull StorageUtils storageUtils, + @Nonnull ExecutorService executorService) { + super(executorService, storageUtils); + this.azureStorageClientProvider = azureStorageClientProvider; + } + + @Override + public CompletableFuture>> fetchObjectsByPage( + String containerName, String prefix, String continuationToken, String startAfter) { + + log.debug( + "fetching files in container {} with prefix {} continuationToken {} startAfter {}", + containerName, + prefix, + continuationToken, + startAfter); + + return CompletableFuture.supplyAsync( + () -> { + try { + DataLakeServiceAsyncClient dataLakeServiceClient = + azureStorageClientProvider.getAzureAsyncClient(); + DataLakeFileSystemAsyncClient fileSystemClient = + dataLakeServiceClient.getFileSystemAsyncClient(containerName); + + ListPathsOptions options = new ListPathsOptions(); + if (StringUtils.isNotBlank(prefix)) { + options.setPath(prefix); + } + + PagedFlux pagedFlux = fileSystemClient.listPaths(options); + + List files = new ArrayList<>(); + String nextContinuationToken = null; + + // Get single page with continuation token + try (PagedResponse page = + StringUtils.isNotBlank(continuationToken) + ? pagedFlux.byPage(continuationToken).blockFirst() + : pagedFlux.byPage().blockFirst()) { + + if (page != null) { + // Process items in the page + page.getElements() + .forEach( + pathItem -> { + String pathName = pathItem.getName(); + boolean isDirectory = pathItem.isDirectory(); + String fileName = pathName.replaceFirst("^" + prefix, ""); + + files.add( + File.builder() + .filename(fileName) + .lastModifiedAt( + isDirectory + ? Instant.EPOCH + : pathItem.getLastModified().toInstant()) + .isDirectory(isDirectory) + .build()); + }); + + // Get continuation token for next page + nextContinuationToken = page.getContinuationToken(); + } + } + + return Pair.of(nextContinuationToken, files); + } catch (Exception ex) { + return handleListPathsException(ex, containerName, prefix); + } + }, + executorService); + } + + @VisibleForTesting + CompletableFuture readBlob(String azureUri) { + log.debug("Reading Azure Data Lake file: {}", azureUri); + return CompletableFuture.supplyAsync( + () -> { + try { + DataLakeFileAsyncClient fileClient = getFileClient(azureUri); + return BinaryData.fromBytes(fileClient.read().blockLast().array()); + } catch (Exception ex) { + log.error("Failed to read file", ex); + throw clientException(ex, "readBlob", azureUri); + } + }, + executorService); + } + + @Override + public CompletableFuture streamFileAsync(String azureUri) { + return readBlob(azureUri) + .thenApply( + binaryData -> + FileStreamData.builder() + .inputStream(new ByteArrayInputStream(binaryData.toBytes())) + .fileSize((long) binaryData.toBytes().length) + .build()); + } + + @Override + public CompletableFuture readFileAsBytes(String azureUri) { + return readBlob(azureUri).thenApply(BinaryData::toBytes); + } + + private DataLakeFileAsyncClient getFileClient(String azureUri) { + String fileSystemName = storageUtils.getBucketNameFromUri(azureUri); + String filePath = storageUtils.getPathFromUrl(azureUri); + + DataLakeServiceAsyncClient dataLakeServiceClient = azureStorageClientProvider.getAzureAsyncClient(); + DataLakeFileSystemAsyncClient fileSystemClient = + dataLakeServiceClient.getFileSystemAsyncClient(fileSystemName); + return fileSystemClient.getFileAsyncClient(filePath); + } + + private Pair> handleListPathsException( + Exception ex, String containerName, String prefix) { + // DataLake API returns 404 for non-existent paths, treat as empty directory + Throwable wrappedException = ex.getCause() != null ? ex.getCause() : ex; + if (wrappedException instanceof DataLakeStorageException) { + DataLakeStorageException dlsException = (DataLakeStorageException) wrappedException; + if ("PathNotFound".equals(dlsException.getErrorCode()) + || dlsException.getStatusCode() == 404) { + log.debug( + "Path not found, returning empty list for container: {}, prefix: {}", + containerName, + prefix); + return Pair.of(null, new ArrayList<>()); + } + } + log.error("Failed to fetch objects by page", ex); + throw clientException(ex, "fetchObjectsByPage", containerName); + } + + @Override + protected RuntimeException clientException(Throwable ex, String operation, String path) { + Throwable wrappedException = ex.getCause() != null ? ex.getCause() : ex; + + if (wrappedException instanceof DataLakeStorageException) { + DataLakeStorageException dataLakeException = (DataLakeStorageException) wrappedException; + String errorCode = dataLakeException.getErrorCode(); + int statusCode = dataLakeException.getStatusCode(); + + log.error( + "Error in Azure Data Lake operation: {} on path: {} code: {} status: {} message: {}", + operation, + path, + errorCode, + statusCode, + dataLakeException.getMessage()); + + // Map to AccessDeniedException + if (statusCode == 403 || statusCode == 401) { + return new AccessDeniedException( + String.format( + "AccessDenied for operation: %s on path: %s with message: %s", + operation, path, dataLakeException.getMessage())); + } + + // Map to NoSuchKeyException + if (errorCode != null + && (errorCode.equals("PathNotFound") + || errorCode.equals("FilesystemNotFound") + || statusCode == 404)) { + return new NoSuchKeyException( + String.format("NoSuchKey for operation: %s on path: %s", operation, path)); + } + + // Map to RateLimitException + if (statusCode == 429 || statusCode == 503) { + return new RateLimitException( + String.format("Throttled by Azure for operation: %s on path: %s", operation, path)); + } + } else if (wrappedException instanceof AccessDeniedException + || wrappedException instanceof NoSuchKeyException + || wrappedException instanceof RateLimitException) { + return (RuntimeException) wrappedException; + } + + return new ObjectStorageClientException(ex); + } + + @Override + public void refreshClient() { + azureStorageClientProvider.refreshClient(); + } + + @Override + public void initializeClient() { + azureStorageClientProvider.getAzureAsyncClient(); + } +} diff --git a/lakeview/src/main/java/ai/onehouse/storage/StorageUtils.java b/lakeview/src/main/java/ai/onehouse/storage/StorageUtils.java index ccd91ba4..d6e488bc 100644 --- a/lakeview/src/main/java/ai/onehouse/storage/StorageUtils.java +++ b/lakeview/src/main/java/ai/onehouse/storage/StorageUtils.java @@ -1,5 +1,7 @@ package ai.onehouse.storage; +import org.apache.commons.lang3.StringUtils; + import static ai.onehouse.constants.StorageConstants.OBJECT_STORAGE_URI_PATTERN; import java.util.regex.Matcher; @@ -7,21 +9,27 @@ public class StorageUtils { private static final String INVALID_STORAGE_URI_ERROR_MSG = "Invalid Object storage Uri: "; + /** + * Group 3 extracts the path portion after the bucket/container name from the URI. + * Examples: + *
    + *
  • s3://my-bucket/path/to/file.txt returns /path/to/file.txt
  • + *
  • gs://my-bucket/path/to/file.txt returns /path/to/file.txt
  • + *
  • https://account.blob.core.windows.net/container/path/to/file.txt returns /path/to/file.txt
  • + *
  • https://account.dfs.core.windows.net/container/path/to/file.txt returns /path/to/file.txt
  • + *
+ * @param uri the storage URI to parse + * @return the path portion of the URI, or empty string if no path + */ public String getPathFromUrl(String uri) { - if (!OBJECT_STORAGE_URI_PATTERN.matcher(uri).matches()) { + Matcher matcher = OBJECT_STORAGE_URI_PATTERN.matcher(uri); + if (!matcher.matches()) { throw new IllegalArgumentException(INVALID_STORAGE_URI_ERROR_MSG + uri); } - String prefix = ""; - - // Remove the scheme and bucket name from the S3 path - int startIndex = uri.indexOf('/', 5); // Skip 's3://' and 'gs://' - if (startIndex != -1) { - prefix = uri.substring(startIndex + 1); - } - - return prefix; + String path = matcher.group(3); + return path == null ? StringUtils.EMPTY : path.replaceFirst("^/", ""); } public String constructFileUri(String directoryUri, String filePath) { @@ -33,6 +41,18 @@ public String constructFileUri(String directoryUri, String filePath) { filePath.startsWith("/") ? filePath.substring(1) : filePath); } + /** + * Group 2 extracts the bucket/container name from the URI. + * Examples: + *
    + *
  • s3://my-bucket-s3/path/to/file.txt returns my-bucket-s3
  • + *
  • gs://my-bucket-gs/path/to/file.txt returns my-bucket-gs
  • + *
  • https://account.blob.core.windows.net/container/path/file.txt returns container
  • + *
  • https://account.dfs.core.windows.net/container/path/file.txt returns container
  • + *
+ * @param uri the storage URI to parse + * @return the bucket or container name + */ public String getBucketNameFromUri(String uri) { Matcher matcher = OBJECT_STORAGE_URI_PATTERN.matcher(uri); if (matcher.matches()) { diff --git a/lakeview/src/main/java/ai/onehouse/storage/providers/AzureStorageClientProvider.java b/lakeview/src/main/java/ai/onehouse/storage/providers/AzureStorageClientProvider.java new file mode 100644 index 00000000..e0944797 --- /dev/null +++ b/lakeview/src/main/java/ai/onehouse/storage/providers/AzureStorageClientProvider.java @@ -0,0 +1,107 @@ +package ai.onehouse.storage.providers; + +import ai.onehouse.config.Config; +import ai.onehouse.config.models.common.AzureConfig; +import ai.onehouse.config.models.common.FileSystemConfiguration; +import com.azure.identity.ClientSecretCredential; +import com.azure.identity.ClientSecretCredentialBuilder; +import com.azure.identity.DefaultAzureCredential; +import com.azure.identity.DefaultAzureCredentialBuilder; +import com.azure.storage.file.datalake.DataLakeServiceAsyncClient; +import com.azure.storage.file.datalake.DataLakeServiceClientBuilder; +import com.azure.storage.common.StorageSharedKeyCredential; +import com.google.common.annotations.VisibleForTesting; +import com.google.inject.Inject; +import java.util.Optional; +import javax.annotation.Nonnull; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class AzureStorageClientProvider { + private final AzureConfig azureConfig; + private static DataLakeServiceAsyncClient azureAsyncClient; + private static final Logger logger = LoggerFactory.getLogger(AzureStorageClientProvider.class); + + @Inject + public AzureStorageClientProvider(@Nonnull Config config) { + FileSystemConfiguration fileSystemConfiguration = config.getFileSystemConfiguration(); + this.azureConfig = fileSystemConfiguration.getAzureConfig(); + } + + @VisibleForTesting + protected DataLakeServiceAsyncClient createAzureAsyncClient() { + logger.debug("Instantiating Azure Data Lake Storage client"); + validateAzureConfig(azureConfig); + + DataLakeServiceClientBuilder builder = new DataLakeServiceClientBuilder(); + String endpoint = String.format("https://%s.dfs.core.windows.net", azureConfig.getAccountName()); + builder.endpoint(endpoint); + + // Option 1: Connection String (includes account key and endpoint) + Optional connectionStringOpt = azureConfig.getConnectionString(); + if (connectionStringOpt.isPresent()) { + logger.debug("Using connection string for authentication"); + builder.connectionString(connectionStringOpt.get()); + return builder.buildAsyncClient(); + } + + // Option 2: Account Key (shared key credential) + Optional accountKeyOpt = azureConfig.getAccountKey(); + if (accountKeyOpt.isPresent()) { + logger.debug("Using account key for authentication"); + StorageSharedKeyCredential credential = + new StorageSharedKeyCredential(azureConfig.getAccountName(), accountKeyOpt.get()); + builder.credential(credential); + return builder.buildAsyncClient(); + } + + // Option 3: Service Principal (client secret credential) + Optional tenantIdOpt = azureConfig.getTenantId(); + Optional clientIdOpt = azureConfig.getClientId(); + Optional clientSecretOpt = azureConfig.getClientSecret(); + if (tenantIdOpt.isPresent() && clientIdOpt.isPresent() && clientSecretOpt.isPresent()) { + logger.debug("Using service principal (client secret) for authentication"); + ClientSecretCredential credential = + new ClientSecretCredentialBuilder() + .tenantId(tenantIdOpt.get()) + .clientId(clientIdOpt.get()) + .clientSecret(clientSecretOpt.get()) + .build(); + builder.credential(credential); + return builder.buildAsyncClient(); + } + + // Option 4: Default Azure Credential (fallback to environment-based auth) + logger.debug("Using default Azure credential chain for authentication"); + DefaultAzureCredential credential = new DefaultAzureCredentialBuilder().build(); + builder.credential(credential); + return builder.buildAsyncClient(); + } + + public DataLakeServiceAsyncClient getAzureAsyncClient() { + if (azureAsyncClient == null) { + azureAsyncClient = createAzureAsyncClient(); + } + return azureAsyncClient; + } + + public void refreshClient() { + azureAsyncClient = createAzureAsyncClient(); + } + + private void validateAzureConfig(AzureConfig azureConfig) { + if (azureConfig == null) { + throw new IllegalArgumentException("Azure Config not found"); + } + + if (StringUtils.isBlank(azureConfig.getAccountName())) { + throw new IllegalArgumentException("Azure storage account name cannot be empty"); + } + } + + @VisibleForTesting + static void resetAzureAsyncClient() { + azureAsyncClient = null; + } +} diff --git a/lakeview/src/test/java/ai/onehouse/TestRuntimeModule.java b/lakeview/src/test/java/ai/onehouse/TestRuntimeModule.java index bb5aa2a6..2c1f5b04 100644 --- a/lakeview/src/test/java/ai/onehouse/TestRuntimeModule.java +++ b/lakeview/src/test/java/ai/onehouse/TestRuntimeModule.java @@ -10,6 +10,7 @@ import ai.onehouse.api.AsyncHttpClientWithRetry; import ai.onehouse.config.Config; +import ai.onehouse.config.models.common.AzureConfig; import ai.onehouse.config.models.common.FileSystemConfiguration; import ai.onehouse.config.models.common.GCSConfig; import ai.onehouse.config.models.common.S3Config; @@ -19,10 +20,12 @@ import ai.onehouse.metadata_extractor.TableDiscoveryService; import ai.onehouse.metadata_extractor.TimelineCommitInstantsUploader; import ai.onehouse.storage.AsyncStorageClient; +import ai.onehouse.storage.AzureAsyncStorageClient; import ai.onehouse.storage.GCSAsyncStorageClient; import ai.onehouse.storage.PresignedUrlFileUploader; import ai.onehouse.storage.S3AsyncStorageClient; import ai.onehouse.storage.StorageUtils; +import ai.onehouse.storage.providers.AzureStorageClientProvider; import ai.onehouse.storage.providers.GcsClientProvider; import ai.onehouse.storage.providers.S3AsyncClientProvider; import com.google.inject.AbstractModule; @@ -109,6 +112,7 @@ void testProvidesAsyncStorageClient(FileSystem fileSystemType) { FileSystemConfiguration mockFileSystemConfiguration = mock(FileSystemConfiguration.class); S3AsyncClientProvider mockS3AsyncClientProvider = mock(S3AsyncClientProvider.class); GcsClientProvider mockGcsClientProvider = mock(GcsClientProvider.class); + AzureStorageClientProvider mockAzureStorageClientProvider = mock(AzureStorageClientProvider.class); when(mockConfig.getFileSystemConfiguration()).thenReturn(mockFileSystemConfiguration); @@ -116,6 +120,10 @@ void testProvidesAsyncStorageClient(FileSystem fileSystemType) { S3Config mockS3Config = mock(S3Config.class); when(mockFileSystemConfiguration.getS3Config()).thenReturn(mockS3Config); when(mockS3AsyncClientProvider.getS3AsyncClient()).thenReturn(null); + } else if (FileSystem.AZURE.equals(fileSystemType)) { + AzureConfig mockAzureConfig = mock(AzureConfig.class); + when(mockFileSystemConfiguration.getS3Config()).thenReturn(null); + when(mockFileSystemConfiguration.getAzureConfig()).thenReturn(mockAzureConfig); } else { GCSConfig mockGcsConfig = mock(GCSConfig.class); when(mockFileSystemConfiguration.getGcsConfig()).thenReturn(mockGcsConfig); @@ -128,9 +136,12 @@ void testProvidesAsyncStorageClient(FileSystem fileSystemType) { mockStorageUtils, mockS3AsyncClientProvider, mockGcsClientProvider, + mockAzureStorageClientProvider, mockExecutorService); if (FileSystem.S3.equals(fileSystemType)) { assertInstanceOf(S3AsyncStorageClient.class, asyncStorageClientForDiscovery); + } else if (FileSystem.AZURE.equals(fileSystemType)) { + assertInstanceOf(AzureAsyncStorageClient.class, asyncStorageClientForDiscovery); } else { assertInstanceOf(GCSAsyncStorageClient.class, asyncStorageClientForDiscovery); } @@ -141,9 +152,12 @@ void testProvidesAsyncStorageClient(FileSystem fileSystemType) { mockStorageUtils, mockS3AsyncClientProvider, mockGcsClientProvider, + mockAzureStorageClientProvider, mockExecutorService); if (FileSystem.S3.equals(fileSystemType)) { Assertions.assertInstanceOf(S3AsyncStorageClient.class, asyncStorageClientForUpload); + } else if (FileSystem.AZURE.equals(fileSystemType)) { + Assertions.assertInstanceOf(AzureAsyncStorageClient.class, asyncStorageClientForUpload); } else { Assertions.assertInstanceOf(GCSAsyncStorageClient.class, asyncStorageClientForUpload); } @@ -172,6 +186,7 @@ static class GuiceTestModule extends AbstractModule { private final StorageUtils storageUtils; private final S3AsyncClientProvider s3Provider; private final GcsClientProvider gcsProvider; + private final AzureStorageClientProvider azureProvider; private final ExecutorService executorService; private final Metrics metrics; private final LakeViewExtractorMetrics lakeViewExtractorMetrics; @@ -179,13 +194,15 @@ static class GuiceTestModule extends AbstractModule { private final OnehouseApiClient onehouseApiClient; GuiceTestModule(Config config, StorageUtils storageUtils, S3AsyncClientProvider s3Provider, - GcsClientProvider gcsProvider, ExecutorService executorService, - Metrics metrics, LakeViewExtractorMetrics lakeViewExtractorMetrics, + GcsClientProvider gcsProvider, AzureStorageClientProvider azureProvider, + ExecutorService executorService, Metrics metrics, + LakeViewExtractorMetrics lakeViewExtractorMetrics, AsyncHttpClientWithRetry httpClient, OnehouseApiClient onehouseApiClient) { this.config = config; this.storageUtils = storageUtils; this.s3Provider = s3Provider; this.gcsProvider = gcsProvider; + this.azureProvider = azureProvider; this.executorService = executorService; this.metrics = metrics; this.lakeViewExtractorMetrics = lakeViewExtractorMetrics; @@ -200,6 +217,7 @@ protected void configure() { bind(StorageUtils.class).toInstance(storageUtils); bind(S3AsyncClientProvider.class).toInstance(s3Provider); bind(GcsClientProvider.class).toInstance(gcsProvider); + bind(AzureStorageClientProvider.class).toInstance(azureProvider); bind(ExecutorService.class).toInstance(executorService); bind(Metrics.class).toInstance(metrics); bind(LakeViewExtractorMetrics.class).toInstance(lakeViewExtractorMetrics); @@ -208,7 +226,7 @@ protected void configure() { } @Test - void testGuiceBootstrapping_S3_and_GCS() { + void testGuiceBootstrapping_S3_Azure_and_GCS() { // S3 setup FileSystemConfiguration mockFsConfig = mock(FileSystemConfiguration.class); S3Config mockS3Config = mock(S3Config.class); @@ -222,6 +240,7 @@ void testGuiceBootstrapping_S3_and_GCS() { mock(StorageUtils.class), mock(S3AsyncClientProvider.class), mock(GcsClientProvider.class), + mock(AzureStorageClientProvider.class), mock(ExecutorService.class), mock(Metrics.class), mock(LakeViewExtractorMetrics.class), @@ -236,6 +255,34 @@ void testGuiceBootstrapping_S3_and_GCS() { Key.get(AsyncStorageClient.class, RuntimeModule.TableDiscoveryObjectStorageAsyncClient.class)); Assertions.assertInstanceOf(S3AsyncStorageClient.class, s3ClientDiscovery); + // Azure setup + FileSystemConfiguration mockFsConfigAzure = mock(FileSystemConfiguration.class); + AzureConfig mockAzureConfig = mock(AzureConfig.class); + when(mockConfig.getFileSystemConfiguration()).thenReturn(mockFsConfigAzure); + when(mockFsConfigAzure.getS3Config()).thenReturn(null); + when(mockFsConfigAzure.getAzureConfig()).thenReturn(mockAzureConfig); + + Injector injectorAzure = Guice.createInjector( + Modules.override(new RuntimeModule(mockConfig)) + .with(new GuiceTestModule( + mockConfig, + mock(StorageUtils.class), + mock(S3AsyncClientProvider.class), + mock(GcsClientProvider.class), + mock(AzureStorageClientProvider.class), + mock(ExecutorService.class), + mock(Metrics.class), + mock(LakeViewExtractorMetrics.class), + mock(AsyncHttpClientWithRetry.class), + mock(OnehouseApiClient.class))) + ); + + AsyncStorageClient azureClientUpload = injectorAzure.getInstance( + Key.get(AsyncStorageClient.class, RuntimeModule.TableMetadataUploadObjectStorageAsyncClient.class)); + Assertions.assertInstanceOf(AzureAsyncStorageClient.class, azureClientUpload); + AsyncStorageClient azureClientDiscovery = injectorAzure.getInstance( + Key.get(AsyncStorageClient.class, RuntimeModule.TableDiscoveryObjectStorageAsyncClient.class)); + Assertions.assertInstanceOf(AzureAsyncStorageClient.class, azureClientDiscovery); // GCS setup FileSystemConfiguration mockFsConfigGcs = mock(FileSystemConfiguration.class); @@ -252,8 +299,8 @@ void testGuiceBootstrapping_S3_and_GCS() { Modules.override(new RuntimeModule(mockConfig)) .with(new GuiceTestModule( mockConfig, mock(StorageUtils.class), mock(S3AsyncClientProvider.class), - mock(GcsClientProvider.class), mock(ExecutorService.class), - mockMetrics, mockLakeViewExtractorMetrics, + mock(GcsClientProvider.class), mock(AzureStorageClientProvider.class), + mock(ExecutorService.class), mockMetrics, mockLakeViewExtractorMetrics, mockHttpClient, mockOnehouseApiClient)) ); @@ -274,6 +321,12 @@ mockConfig, mock(StorageUtils.class), mock(S3AsyncClientProvider.class), assertNotNull(injectorS3.getInstance(TableDiscoveryService.class)); assertNotNull(injectorS3.getInstance(TimelineCommitInstantsUploader.class)); assertNotNull(injectorS3.getInstance(PresignedUrlFileUploader.class)); + assertNotNull(injectorAzure.getInstance(AzureStorageClientProvider.class)); + assertNotNull(injectorAzure.getInstance(HoodiePropertiesReader.class)); + assertNotNull(injectorAzure.getInstance(TableDiscoveryAndUploadJob.class)); + assertNotNull(injectorAzure.getInstance(TableDiscoveryService.class)); + assertNotNull(injectorAzure.getInstance(TimelineCommitInstantsUploader.class)); + assertNotNull(injectorAzure.getInstance(PresignedUrlFileUploader.class)); assertNotNull(injectorGcs.getInstance(GcsClientProvider.class)); assertNotNull(injectorGcs.getInstance(HoodiePropertiesReader.class)); assertNotNull(injectorGcs.getInstance(TableDiscoveryAndUploadJob.class)); @@ -284,6 +337,7 @@ mockConfig, mock(StorageUtils.class), mock(S3AsyncClientProvider.class), enum FileSystem { S3, + AZURE, GCS } } diff --git a/lakeview/src/test/java/ai/onehouse/storage/AzureAsyncStorageClientTest.java b/lakeview/src/test/java/ai/onehouse/storage/AzureAsyncStorageClientTest.java new file mode 100644 index 00000000..8f71d225 --- /dev/null +++ b/lakeview/src/test/java/ai/onehouse/storage/AzureAsyncStorageClientTest.java @@ -0,0 +1,400 @@ +package ai.onehouse.storage; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +import ai.onehouse.exceptions.AccessDeniedException; +import ai.onehouse.exceptions.NoSuchKeyException; +import ai.onehouse.exceptions.ObjectStorageClientException; +import ai.onehouse.exceptions.RateLimitException; +import ai.onehouse.storage.models.File; +import ai.onehouse.storage.models.FileStreamData; +import ai.onehouse.storage.providers.AzureStorageClientProvider; +import com.azure.core.http.rest.PagedFlux; +import com.azure.core.http.rest.PagedResponse; +import com.azure.core.util.BinaryData; +import com.azure.storage.file.datalake.DataLakeFileAsyncClient; +import com.azure.storage.file.datalake.DataLakeFileSystemAsyncClient; +import com.azure.storage.file.datalake.DataLakeServiceAsyncClient; +import com.azure.storage.file.datalake.models.DataLakeStorageException; +import com.azure.core.util.IterableStream; +import com.azure.storage.file.datalake.models.PathItem; +import java.nio.ByteBuffer; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.time.OffsetDateTime; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ForkJoinPool; +import java.util.stream.Stream; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +@ExtendWith(MockitoExtension.class) +class AzureAsyncStorageClientTest { + + @Mock private AzureStorageClientProvider mockAzureStorageClientProvider; + @Mock private StorageUtils mockStorageUtils; + @Mock private DataLakeServiceAsyncClient mockDataLakeServiceAsyncClient; + @Mock private DataLakeFileSystemAsyncClient mockFileSystemAsyncClient; + @Mock private DataLakeFileAsyncClient mockFileAsyncClient; + @Mock private PagedFlux mockPagedFlux; + @Mock private PagedResponse mockPagedResponse1; + @Mock private PagedResponse mockPagedResponse2; + @Mock private PathItem mockPathItem1; + @Mock private PathItem mockPathItem2; + + private AzureAsyncStorageClient azureAsyncStorageClient; + private static final String AZURE_URI = + "https://testaccount.dfs.core.windows.net/test-container/test-file"; + private static final String TEST_CONTAINER = "test-container"; + private static final String TEST_FILE = "test-file"; + + @BeforeEach + void setup() { + lenient() + .when(mockAzureStorageClientProvider.getAzureAsyncClient()) + .thenReturn(mockDataLakeServiceAsyncClient); + lenient().when(mockStorageUtils.getBucketNameFromUri(AZURE_URI)).thenReturn(TEST_CONTAINER); + lenient().when(mockStorageUtils.getPathFromUrl(AZURE_URI)).thenReturn(TEST_FILE); + azureAsyncStorageClient = + new AzureAsyncStorageClient( + mockAzureStorageClientProvider, mockStorageUtils, ForkJoinPool.commonPool()); + } + + @Test + void testListAllFilesInDir() throws ExecutionException, InterruptedException { + String fileName = "file1"; + String dirName = "dir1/"; + String continuationToken = "page_2"; + String prefix = TEST_FILE + "/"; + + when(mockDataLakeServiceAsyncClient.getFileSystemAsyncClient(TEST_CONTAINER)) + .thenReturn(mockFileSystemAsyncClient); + when(mockFileSystemAsyncClient.listPaths(any())) + .thenReturn(mockPagedFlux); + when(mockPagedFlux.byPage()).thenReturn(Flux.just(mockPagedResponse1)); + when(mockPagedFlux.byPage(continuationToken)).thenReturn(Flux.just(mockPagedResponse2)); + + // First page + when(mockPagedResponse1.getElements()) + .thenReturn(IterableStream.of(Arrays.asList(mockPathItem1))); + when(mockPagedResponse1.getContinuationToken()).thenReturn(continuationToken); + when(mockPathItem1.getName()).thenReturn(prefix + fileName); + when(mockPathItem1.isDirectory()).thenReturn(false); + when(mockPathItem1.getLastModified()).thenReturn(OffsetDateTime.now()); + + // Second page + when(mockPagedResponse2.getElements()) + .thenReturn(IterableStream.of(Arrays.asList(mockPathItem2))); + when(mockPagedResponse2.getContinuationToken()).thenReturn(null); + when(mockPathItem2.getName()).thenReturn(prefix + dirName); + when(mockPathItem2.isDirectory()).thenReturn(true); + + List result = azureAsyncStorageClient.listAllFilesInDir(AZURE_URI).get(); + + assertEquals(2, result.size()); + assertFalse(result.get(0).isDirectory()); + assertEquals(fileName, result.get(0).getFilename()); + assertTrue(result.get(1).isDirectory()); + assertEquals(dirName, result.get(1).getFilename()); + } + + @Test + void testFetchObjectsByPage() throws ExecutionException, InterruptedException { + String fileName = "file1"; + String prefix = "prefix"; + String continuationToken = "token"; + String nextToken = "next-token"; + + when(mockDataLakeServiceAsyncClient.getFileSystemAsyncClient(TEST_CONTAINER)) + .thenReturn(mockFileSystemAsyncClient); + when(mockFileSystemAsyncClient.listPaths(any())) + .thenReturn(mockPagedFlux); + when(mockPagedFlux.byPage(continuationToken)) + .thenReturn(Flux.just(mockPagedResponse1)); + + when(mockPagedResponse1.getElements()) + .thenReturn(IterableStream.of(Arrays.asList(mockPathItem1))); + when(mockPagedResponse1.getContinuationToken()).thenReturn(nextToken); + when(mockPathItem1.getName()).thenReturn(prefix + "/" + fileName); + when(mockPathItem1.isDirectory()).thenReturn(false); + when(mockPathItem1.getLastModified()).thenReturn(OffsetDateTime.now()); + + Pair> result = + azureAsyncStorageClient + .fetchObjectsByPage(TEST_CONTAINER, prefix, continuationToken, null) + .get(); + + assertEquals(nextToken, result.getLeft()); + assertEquals(1, result.getRight().size()); + assertEquals("/" + fileName, result.getRight().get(0).getFilename()); + } + + @Test + void testFetchObjectsByPageWithoutContinuationToken() + throws ExecutionException, InterruptedException { + String fileName = "file1"; + String prefix = "prefix"; + + when(mockDataLakeServiceAsyncClient.getFileSystemAsyncClient(TEST_CONTAINER)) + .thenReturn(mockFileSystemAsyncClient); + when(mockFileSystemAsyncClient.listPaths(any())) + .thenReturn(mockPagedFlux); + when(mockPagedFlux.byPage()).thenReturn(Flux.just(mockPagedResponse1)); + + when(mockPagedResponse1.getElements()) + .thenReturn(IterableStream.of(Arrays.asList(mockPathItem1))); + when(mockPagedResponse1.getContinuationToken()).thenReturn(null); + when(mockPathItem1.getName()).thenReturn(prefix + "/" + fileName); + when(mockPathItem1.isDirectory()).thenReturn(false); + when(mockPathItem1.getLastModified()).thenReturn(OffsetDateTime.now()); + + Pair> result = + azureAsyncStorageClient.fetchObjectsByPage(TEST_CONTAINER, prefix, null, null).get(); + + assertNull(result.getLeft()); + assertEquals(1, result.getRight().size()); + } + + @Test + void testReadBlob() throws ExecutionException, InterruptedException { + byte[] fileContent = "test content".getBytes(StandardCharsets.UTF_8); + ByteBuffer byteBuffer = ByteBuffer.wrap(fileContent); + + when(mockDataLakeServiceAsyncClient.getFileSystemAsyncClient(TEST_CONTAINER)) + .thenReturn(mockFileSystemAsyncClient); + when(mockFileSystemAsyncClient.getFileAsyncClient(TEST_FILE)).thenReturn(mockFileAsyncClient); + when(mockFileAsyncClient.read()).thenReturn(Flux.just(byteBuffer)); + + BinaryData result = azureAsyncStorageClient.readBlob(AZURE_URI).get(); + + assertNotNull(result); + assertArrayEquals(fileContent, result.toBytes()); + } + + @Test + void testStreamFileAsync() throws ExecutionException, InterruptedException, IOException { + byte[] fileContent = "test content".getBytes(StandardCharsets.UTF_8); + ByteBuffer byteBuffer = ByteBuffer.wrap(fileContent); + + when(mockDataLakeServiceAsyncClient.getFileSystemAsyncClient(TEST_CONTAINER)) + .thenReturn(mockFileSystemAsyncClient); + when(mockFileSystemAsyncClient.getFileAsyncClient(TEST_FILE)).thenReturn(mockFileAsyncClient); + when(mockFileAsyncClient.read()).thenReturn(Flux.just(byteBuffer)); + + FileStreamData result = azureAsyncStorageClient.streamFileAsync(AZURE_URI).get(); + + assertNotNull(result); + assertEquals(fileContent.length, result.getFileSize()); + + byte[] resultContent = toByteArray(result.getInputStream()); + assertArrayEquals(fileContent, resultContent); + } + + @Test + void testReadFileAsBytes() throws ExecutionException, InterruptedException { + byte[] fileContent = "test content".getBytes(StandardCharsets.UTF_8); + ByteBuffer byteBuffer = ByteBuffer.wrap(fileContent); + + when(mockDataLakeServiceAsyncClient.getFileSystemAsyncClient(TEST_CONTAINER)) + .thenReturn(mockFileSystemAsyncClient); + when(mockFileSystemAsyncClient.getFileAsyncClient(TEST_FILE)).thenReturn(mockFileAsyncClient); + when(mockFileAsyncClient.read()).thenReturn(Flux.just(byteBuffer)); + + byte[] result = azureAsyncStorageClient.readFileAsBytes(AZURE_URI).get(); + + assertArrayEquals(fileContent, result); + } + + @ParameterizedTest + @MethodSource("generateDataLakeStorageExceptionTestCases") + void testReadBlobWithDataLakeStorageException( + DataLakeStorageException exception, Class expectedExceptionClass) { + when(mockDataLakeServiceAsyncClient.getFileSystemAsyncClient(TEST_CONTAINER)) + .thenReturn(mockFileSystemAsyncClient); + when(mockFileSystemAsyncClient.getFileAsyncClient(TEST_FILE)).thenReturn(mockFileAsyncClient); + when(mockFileAsyncClient.read()).thenAnswer(invocation -> Flux.error(exception)); + + CompletableFuture future = azureAsyncStorageClient.readBlob(AZURE_URI); + CompletionException executionException = assertThrows(CompletionException.class, future::join); + + Throwable cause = executionException.getCause(); + assertInstanceOf(expectedExceptionClass, cause); + } + + static Stream generateDataLakeStorageExceptionTestCases() { + DataLakeStorageException accessDenied403 = mock(DataLakeStorageException.class); + when(accessDenied403.getStatusCode()).thenReturn(403); + when(accessDenied403.getMessage()).thenReturn("Access denied"); + when(accessDenied403.getErrorCode()).thenReturn(null); + + DataLakeStorageException unauthorized401 = mock(DataLakeStorageException.class); + when(unauthorized401.getStatusCode()).thenReturn(401); + when(unauthorized401.getMessage()).thenReturn("Unauthorized"); + when(unauthorized401.getErrorCode()).thenReturn(null); + + DataLakeStorageException pathNotFound = mock(DataLakeStorageException.class); + when(pathNotFound.getStatusCode()).thenReturn(404); + when(pathNotFound.getErrorCode()).thenReturn("PathNotFound"); + when(pathNotFound.getMessage()).thenReturn("Path not found"); + + DataLakeStorageException filesystemNotFound = mock(DataLakeStorageException.class); + when(filesystemNotFound.getStatusCode()).thenReturn(404); + when(filesystemNotFound.getErrorCode()).thenReturn("FilesystemNotFound"); + when(filesystemNotFound.getMessage()).thenReturn("Filesystem not found"); + + DataLakeStorageException tooManyRequests = mock(DataLakeStorageException.class); + when(tooManyRequests.getStatusCode()).thenReturn(429); + when(tooManyRequests.getMessage()).thenReturn("Too many requests"); + when(tooManyRequests.getErrorCode()).thenReturn(null); + + DataLakeStorageException serviceUnavailable = mock(DataLakeStorageException.class); + when(serviceUnavailable.getStatusCode()).thenReturn(503); + when(serviceUnavailable.getMessage()).thenReturn("Service unavailable"); + when(serviceUnavailable.getErrorCode()).thenReturn(null); + + DataLakeStorageException internalError = mock(DataLakeStorageException.class); + when(internalError.getStatusCode()).thenReturn(500); + when(internalError.getMessage()).thenReturn("Internal error"); + when(internalError.getErrorCode()).thenReturn(null); + + return Stream.of( + Arguments.of(accessDenied403, AccessDeniedException.class), + Arguments.of(unauthorized401, AccessDeniedException.class), + Arguments.of(pathNotFound, NoSuchKeyException.class), + Arguments.of(filesystemNotFound, NoSuchKeyException.class), + Arguments.of(tooManyRequests, RateLimitException.class), + Arguments.of(serviceUnavailable, RateLimitException.class), + Arguments.of(internalError, ObjectStorageClientException.class)); + } + + @ParameterizedTest + @MethodSource("generateWrappedExceptionTestCases") + void testReadBlobWithWrappedException( + RuntimeException exception, Class expectedExceptionClass) { + when(mockDataLakeServiceAsyncClient.getFileSystemAsyncClient(TEST_CONTAINER)) + .thenReturn(mockFileSystemAsyncClient); + when(mockFileSystemAsyncClient.getFileAsyncClient(TEST_FILE)).thenReturn(mockFileAsyncClient); + when(mockFileAsyncClient.read()).thenAnswer(invocation -> Flux.error(exception)); + + CompletableFuture future = azureAsyncStorageClient.readBlob(AZURE_URI); + CompletionException executionException = assertThrows(CompletionException.class, future::join); + + Throwable cause = executionException.getCause(); + assertInstanceOf(expectedExceptionClass, cause); + } + + static Stream generateWrappedExceptionTestCases() { + return Stream.of( + Arguments.of(new AccessDeniedException("error"), AccessDeniedException.class), + Arguments.of(new NoSuchKeyException("error"), NoSuchKeyException.class), + Arguments.of(new RateLimitException("error"), RateLimitException.class)); + } + + @Test + void testFetchObjectsByPageWithException() { + DataLakeStorageException exception = mock(DataLakeStorageException.class); + when(exception.getStatusCode()).thenReturn(429); + when(exception.getMessage()).thenReturn("Throttled"); + when(exception.getErrorCode()).thenReturn(null); + + when(mockDataLakeServiceAsyncClient.getFileSystemAsyncClient(TEST_CONTAINER)) + .thenReturn(mockFileSystemAsyncClient); + when(mockFileSystemAsyncClient.listPaths(any())) + .thenThrow(exception); + + CompletableFuture>> future = + azureAsyncStorageClient.fetchObjectsByPage(TEST_CONTAINER, "prefix", null, null); + CompletionException executionException = assertThrows(CompletionException.class, future::join); + + Throwable cause = executionException.getCause(); + assertInstanceOf(RateLimitException.class, cause); + assertTrue( + cause.getMessage().contains("Throttled by Azure for operation: fetchObjectsByPage")); + } + + @Test + void testRefreshClient() { + azureAsyncStorageClient.refreshClient(); + verify(mockAzureStorageClientProvider, times(1)).refreshClient(); + } + + @Test + void testInitializeClient() { + azureAsyncStorageClient.initializeClient(); + verify(mockAzureStorageClientProvider, times(1)).getAzureAsyncClient(); + } + + @Test + void testReadBlobWithGenericException() { + RuntimeException exception = new RuntimeException("Generic error"); + + when(mockDataLakeServiceAsyncClient.getFileSystemAsyncClient(TEST_CONTAINER)) + .thenReturn(mockFileSystemAsyncClient); + when(mockFileSystemAsyncClient.getFileAsyncClient(TEST_FILE)).thenReturn(mockFileAsyncClient); + when(mockFileAsyncClient.read()).thenAnswer(invocation -> Flux.error(exception)); + + CompletableFuture future = azureAsyncStorageClient.readBlob(AZURE_URI); + CompletionException executionException = assertThrows(CompletionException.class, future::join); + + Throwable cause = executionException.getCause(); + assertInstanceOf(ObjectStorageClientException.class, cause); + } + + @Test + void testFetchObjectsByPageWithDirectoryItems() throws ExecutionException, InterruptedException { + String dirName = "dir1/"; + String prefix = "prefix"; + + when(mockDataLakeServiceAsyncClient.getFileSystemAsyncClient(TEST_CONTAINER)) + .thenReturn(mockFileSystemAsyncClient); + when(mockFileSystemAsyncClient.listPaths(any())) + .thenReturn(mockPagedFlux); + when(mockPagedFlux.byPage()).thenReturn(Flux.just(mockPagedResponse1)); + + when(mockPagedResponse1.getElements()) + .thenReturn(IterableStream.of(Arrays.asList(mockPathItem1))); + when(mockPagedResponse1.getContinuationToken()).thenReturn(null); + when(mockPathItem1.getName()).thenReturn(prefix + dirName); + when(mockPathItem1.isDirectory()).thenReturn(true); + + Pair> result = + azureAsyncStorageClient.fetchObjectsByPage(TEST_CONTAINER, prefix, null, null).get(); + + assertEquals(1, result.getRight().size()); + assertTrue(result.getRight().get(0).isDirectory()); + assertEquals(dirName, result.getRight().get(0).getFilename()); + assertEquals(Instant.EPOCH, result.getRight().get(0).getLastModifiedAt()); + } + + private static byte[] toByteArray(InputStream is) throws IOException { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { + byte[] buffer = new byte[8192]; + int bytesRead; + while ((bytesRead = is.read(buffer)) != -1) { + baos.write(buffer, 0, bytesRead); + } + return baos.toByteArray(); + } + } + +} diff --git a/lakeview/src/test/java/ai/onehouse/storage/StorageUtilsTest.java b/lakeview/src/test/java/ai/onehouse/storage/StorageUtilsTest.java index 4fb326c3..bcfad568 100644 --- a/lakeview/src/test/java/ai/onehouse/storage/StorageUtilsTest.java +++ b/lakeview/src/test/java/ai/onehouse/storage/StorageUtilsTest.java @@ -11,8 +11,20 @@ class StorageUtilsTest { void testGetPathFromUrl() { assertEquals("path/to/file", storageUtils.getPathFromUrl("s3://bucket/path/to/file")); assertEquals("path/to/file", storageUtils.getPathFromUrl("gs://bucket/path/to/file")); + assertEquals( + "path/to/file", + storageUtils.getPathFromUrl( + "https://account.blob.core.windows.net/container/path/to/file")); + assertEquals( + "path/to/file", + storageUtils.getPathFromUrl( + "https://account.dfs.core.windows.net/container/path/to/file")); assertEquals("", storageUtils.getPathFromUrl("s3://bucket")); assertEquals("", storageUtils.getPathFromUrl("gs://bucket")); + assertEquals( + "", storageUtils.getPathFromUrl("https://account.blob.core.windows.net/container")); + assertEquals( + "", storageUtils.getPathFromUrl("https://account.dfs.core.windows.net/container")); assertThrows(IllegalArgumentException.class, () -> storageUtils.getPathFromUrl("invalidUri")); } @@ -20,21 +32,45 @@ void testGetPathFromUrl() { void testConstructFileUri() { String s3DirUriWithoutTrailingSlash = "s3://bucket/dir1"; String s3DirUriWithTrailingSlash = "s3://bucket/dir1/"; + String azureDirUriWithoutTrailingSlash = + "https://account.blob.core.windows.net/container/dir1"; + String azureDirUriWithTrailingSlash = + "https://account.blob.core.windows.net/container/dir1/"; String filePathWithoutPrefixSlash = "file.txt"; String filePathWithPrefixSlash = "/file.txt"; - String expectedFileUri = s3DirUriWithTrailingSlash + filePathWithoutPrefixSlash; + String expectedS3FileUri = s3DirUriWithTrailingSlash + filePathWithoutPrefixSlash; + String expectedAzureFileUri = azureDirUriWithTrailingSlash + filePathWithoutPrefixSlash; + + // S3 tests assertEquals( - expectedFileUri, + expectedS3FileUri, storageUtils.constructFileUri(s3DirUriWithoutTrailingSlash, filePathWithoutPrefixSlash)); assertEquals( - expectedFileUri, + expectedS3FileUri, storageUtils.constructFileUri(s3DirUriWithTrailingSlash, filePathWithoutPrefixSlash)); assertEquals( - expectedFileUri, + expectedS3FileUri, storageUtils.constructFileUri(s3DirUriWithoutTrailingSlash, filePathWithPrefixSlash)); assertEquals( - expectedFileUri, + expectedS3FileUri, storageUtils.constructFileUri(s3DirUriWithTrailingSlash, filePathWithPrefixSlash)); + + // Azure tests + assertEquals( + expectedAzureFileUri, + storageUtils.constructFileUri( + azureDirUriWithoutTrailingSlash, filePathWithoutPrefixSlash)); + assertEquals( + expectedAzureFileUri, + storageUtils.constructFileUri(azureDirUriWithTrailingSlash, filePathWithoutPrefixSlash)); + assertEquals( + expectedAzureFileUri, + storageUtils.constructFileUri(azureDirUriWithoutTrailingSlash, filePathWithPrefixSlash)); + assertEquals( + expectedAzureFileUri, + storageUtils.constructFileUri(azureDirUriWithTrailingSlash, filePathWithPrefixSlash)); + + // Edge cases assertEquals( filePathWithPrefixSlash, storageUtils.constructFileUri("", filePathWithoutPrefixSlash)); assertEquals( @@ -49,6 +85,22 @@ void testConstructFileUri() { void testGetBucketNameFromUri() { assertEquals("bucket", storageUtils.getBucketNameFromUri("s3://bucket/path/to/file")); assertEquals("bucket", storageUtils.getBucketNameFromUri("gs://bucket/path/to/file")); + assertEquals( + "container", + storageUtils.getBucketNameFromUri( + "https://account.blob.core.windows.net/container/path/to/file")); + assertEquals( + "container", + storageUtils.getBucketNameFromUri( + "https://account.dfs.core.windows.net/container/path/to/file")); + assertEquals("bucket", storageUtils.getBucketNameFromUri("s3://bucket")); + assertEquals("bucket", storageUtils.getBucketNameFromUri("gs://bucket")); + assertEquals( + "container", + storageUtils.getBucketNameFromUri("https://account.blob.core.windows.net/container")); + assertEquals( + "container", + storageUtils.getBucketNameFromUri("https://account.dfs.core.windows.net/container")); assertThrows( IllegalArgumentException.class, () -> storageUtils.getBucketNameFromUri("invalidUri")); } diff --git a/lakeview/src/test/java/ai/onehouse/storage/providers/AzureStorageClientProviderTest.java b/lakeview/src/test/java/ai/onehouse/storage/providers/AzureStorageClientProviderTest.java new file mode 100644 index 00000000..dc878264 --- /dev/null +++ b/lakeview/src/test/java/ai/onehouse/storage/providers/AzureStorageClientProviderTest.java @@ -0,0 +1,253 @@ +package ai.onehouse.storage.providers; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import ai.onehouse.config.models.common.AzureConfig; +import ai.onehouse.config.models.common.FileSystemConfiguration; +import ai.onehouse.config.models.configv1.ConfigV1; +import com.azure.identity.ClientSecretCredential; +import com.azure.identity.ClientSecretCredentialBuilder; +import com.azure.identity.DefaultAzureCredential; +import com.azure.identity.DefaultAzureCredentialBuilder; +import com.azure.storage.file.datalake.DataLakeServiceAsyncClient; +import com.azure.storage.file.datalake.DataLakeServiceClientBuilder; +import com.azure.storage.common.StorageSharedKeyCredential; +import java.util.Optional; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import org.mockito.MockedConstruction; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class AzureStorageClientProviderTest { + @Mock private ConfigV1 config; + @Mock private FileSystemConfiguration fileSystemConfiguration; + @Mock private AzureConfig azureConfig; + @Mock private DataLakeServiceAsyncClient mockDataLakeServiceAsyncClient; + + @BeforeEach + void setup() { + when(config.getFileSystemConfiguration()).thenReturn(fileSystemConfiguration); + when(fileSystemConfiguration.getAzureConfig()).thenReturn(azureConfig); + } + + @Test + void throwExceptionWhenAzureConfigIsNull() { + when(fileSystemConfiguration.getAzureConfig()).thenReturn(null); + AzureStorageClientProvider clientProvider = new AzureStorageClientProvider(config); + + IllegalArgumentException thrown = + assertThrows(IllegalArgumentException.class, clientProvider::createAzureAsyncClient); + + assertEquals("Azure Config not found", thrown.getMessage()); + } + + @Test + void throwExceptionWhenAccountNameIsBlank() { + when(azureConfig.getAccountName()).thenReturn(""); + + AzureStorageClientProvider clientProvider = new AzureStorageClientProvider(config); + IllegalArgumentException thrown = + assertThrows(IllegalArgumentException.class, clientProvider::createAzureAsyncClient); + + assertEquals("Azure storage account name cannot be empty", thrown.getMessage()); + } + + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void testInstantiateAzureClientWithConnectionString(boolean isRefreshClient) { + when(azureConfig.getAccountName()).thenReturn("testaccount"); + when(azureConfig.getConnectionString()) + .thenReturn( + Optional.of( + "DefaultEndpointsProtocol=https;AccountName=testaccount;AccountKey=key;EndpointSuffix=core.windows.net")); + + try (MockedConstruction mockedBuilder = + mockConstruction( + DataLakeServiceClientBuilder.class, + (mock, context) -> { + when(mock.endpoint(anyString())).thenReturn(mock); + when(mock.connectionString(anyString())).thenReturn(mock); + when(mock.buildAsyncClient()).thenReturn(mockDataLakeServiceAsyncClient); + })) { + + AzureStorageClientProvider azureClientProviderSpy = + Mockito.spy(new AzureStorageClientProvider(config)); + AzureStorageClientProvider.resetAzureAsyncClient(); + + if (!isRefreshClient) { + DataLakeServiceAsyncClient result = azureClientProviderSpy.getAzureAsyncClient(); + assertNotNull(result); + } else { + azureClientProviderSpy.refreshClient(); + } + + verify(azureClientProviderSpy, times(1)).createAzureAsyncClient(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void testInstantiateAzureClientWithAccountKey(boolean isRefreshClient) { + when(azureConfig.getAccountName()).thenReturn("testaccount"); + when(azureConfig.getConnectionString()).thenReturn(Optional.empty()); + when(azureConfig.getAccountKey()).thenReturn(Optional.of("dGVzdGFjY291bnRrZXk=")); + + try (MockedConstruction mockedBuilder = + mockConstruction( + DataLakeServiceClientBuilder.class, + (mock, context) -> { + when(mock.endpoint(anyString())).thenReturn(mock); + when(mock.credential(any(StorageSharedKeyCredential.class))).thenReturn(mock); + when(mock.buildAsyncClient()).thenReturn(mockDataLakeServiceAsyncClient); + }); + MockedConstruction mockedCredential = + mockConstruction(StorageSharedKeyCredential.class)) { + + AzureStorageClientProvider azureClientProviderSpy = + Mockito.spy(new AzureStorageClientProvider(config)); + AzureStorageClientProvider.resetAzureAsyncClient(); + + if (!isRefreshClient) { + DataLakeServiceAsyncClient result = azureClientProviderSpy.getAzureAsyncClient(); + assertNotNull(result); + } else { + azureClientProviderSpy.refreshClient(); + } + + verify(azureClientProviderSpy, times(1)).createAzureAsyncClient(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void testInstantiateAzureClientWithServicePrincipal(boolean isRefreshClient) { + when(azureConfig.getAccountName()).thenReturn("testaccount"); + when(azureConfig.getConnectionString()).thenReturn(Optional.empty()); + when(azureConfig.getAccountKey()).thenReturn(Optional.empty()); + when(azureConfig.getTenantId()).thenReturn(Optional.of("test-tenant-id")); + when(azureConfig.getClientId()).thenReturn(Optional.of("test-client-id")); + when(azureConfig.getClientSecret()).thenReturn(Optional.of("test-client-secret")); + + try (MockedConstruction mockedBuilder = + mockConstruction( + DataLakeServiceClientBuilder.class, + (mock, context) -> { + when(mock.endpoint(anyString())).thenReturn(mock); + when(mock.credential(any(ClientSecretCredential.class))).thenReturn(mock); + when(mock.buildAsyncClient()).thenReturn(mockDataLakeServiceAsyncClient); + }); + MockedConstruction mockedCredBuilder = + mockConstruction( + ClientSecretCredentialBuilder.class, + (mock, context) -> { + when(mock.tenantId(anyString())).thenReturn(mock); + when(mock.clientId(anyString())).thenReturn(mock); + when(mock.clientSecret(anyString())).thenReturn(mock); + when(mock.build()).thenReturn(Mockito.mock(ClientSecretCredential.class)); + })) { + + AzureStorageClientProvider azureClientProviderSpy = + Mockito.spy(new AzureStorageClientProvider(config)); + AzureStorageClientProvider.resetAzureAsyncClient(); + + if (!isRefreshClient) { + DataLakeServiceAsyncClient result = azureClientProviderSpy.getAzureAsyncClient(); + assertNotNull(result); + } else { + azureClientProviderSpy.refreshClient(); + } + + verify(azureClientProviderSpy, times(1)).createAzureAsyncClient(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void testInstantiateAzureClientWithManagedIdentity(boolean isRefreshClient) { + when(azureConfig.getAccountName()).thenReturn("testaccount"); + when(azureConfig.getConnectionString()).thenReturn(Optional.empty()); + when(azureConfig.getAccountKey()).thenReturn(Optional.empty()); + when(azureConfig.getTenantId()).thenReturn(Optional.of("test-tenant-id")); + when(azureConfig.getClientId()).thenReturn(Optional.of("test-client-id")); + when(azureConfig.getClientSecret()).thenReturn(Optional.empty()); + + try (MockedConstruction mockedBuilder = + mockConstruction( + DataLakeServiceClientBuilder.class, + (mock, context) -> { + when(mock.endpoint(anyString())).thenReturn(mock); + when(mock.credential(any(DefaultAzureCredential.class))).thenReturn(mock); + when(mock.buildAsyncClient()).thenReturn(mockDataLakeServiceAsyncClient); + }); + MockedConstruction mockedCredBuilder = + mockConstruction( + DefaultAzureCredentialBuilder.class, + (mock, context) -> { + when(mock.tenantId(anyString())).thenReturn(mock); + when(mock.managedIdentityClientId(anyString())).thenReturn(mock); + when(mock.build()).thenReturn(Mockito.mock(DefaultAzureCredential.class)); + })) { + + AzureStorageClientProvider azureClientProviderSpy = + Mockito.spy(new AzureStorageClientProvider(config)); + AzureStorageClientProvider.resetAzureAsyncClient(); + + if (!isRefreshClient) { + DataLakeServiceAsyncClient result = azureClientProviderSpy.getAzureAsyncClient(); + assertNotNull(result); + } else { + azureClientProviderSpy.refreshClient(); + } + + verify(azureClientProviderSpy, times(1)).createAzureAsyncClient(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void testInstantiateAzureClientWithDefaultCredential(boolean isRefreshClient) { + when(azureConfig.getAccountName()).thenReturn("testaccount"); + when(azureConfig.getConnectionString()).thenReturn(Optional.empty()); + when(azureConfig.getAccountKey()).thenReturn(Optional.empty()); + when(azureConfig.getTenantId()).thenReturn(Optional.empty()); + when(azureConfig.getClientId()).thenReturn(Optional.empty()); + + try (MockedConstruction mockedBuilder = + mockConstruction( + DataLakeServiceClientBuilder.class, + (mock, context) -> { + when(mock.endpoint(anyString())).thenReturn(mock); + when(mock.credential(any(DefaultAzureCredential.class))).thenReturn(mock); + when(mock.buildAsyncClient()).thenReturn(mockDataLakeServiceAsyncClient); + }); + MockedConstruction mockedCredBuilder = + mockConstruction( + DefaultAzureCredentialBuilder.class, + (mock, context) -> { + when(mock.build()).thenReturn(Mockito.mock(DefaultAzureCredential.class)); + })) { + + AzureStorageClientProvider azureClientProviderSpy = + Mockito.spy(new AzureStorageClientProvider(config)); + AzureStorageClientProvider.resetAzureAsyncClient(); + + if (!isRefreshClient) { + DataLakeServiceAsyncClient result = azureClientProviderSpy.getAzureAsyncClient(); + assertNotNull(result); + } else { + azureClientProviderSpy.refreshClient(); + } + + verify(azureClientProviderSpy, times(1)).createAzureAsyncClient(); + } + } +}