Skip to content

Commit

Permalink
Connect to Astra via TLS w/o using SCB (#309)
Browse files Browse the repository at this point in the history
* Implemented connection to Astra via TLS options without SCB
* Replaced hardcoded sides(Origin/Target) with enums
* Cleaned up imports
* Apply suggestions from code review
---------
Co-authored-by: Madhavan <msmygit@users.noreply.github.com>
  • Loading branch information
pravinbhat authored Sep 18, 2024
1 parent d81cdfb commit 0bab9e9
Show file tree
Hide file tree
Showing 20 changed files with 219 additions and 39 deletions.
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Release Notes
## [4.4.0] - 2024-09-19
- Added property `spark.cdm.connect.origin.tls.isAstra` and `spark.cdm.connect.target.tls.isAstra` to allow connecting to Astra DB without using [SCB](https://docs.datastax.com/en/astra-db-serverless/drivers/secure-connect-bundle.html). This may be needed for enterprises that may find credentials packaged within SCB as a security risk. TLS properties can now be passed as params OR wrapper scripts (not included) could be used to pull sensitive credentials from a vault service in real-time & pass them to CDM.

## [4.3.10] - 2024-09-12
- Added property `spark.cdm.trackRun.runId` to support a custom unique identifier for the current run. This can be used by wrapper scripts to pass a known `runId` and then use it to query the `cdm_run_info` and `cdm_run_details` tables.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import com.datastax.cdm.cql.EnhancedSession;
import com.datastax.cdm.properties.IPropertyHelper;
import com.datastax.cdm.schema.CqlTable;
import com.datastax.oss.driver.api.core.cql.*;
import com.datastax.oss.driver.api.core.cql.PreparedStatement;

public class BaseCdmStatement {

Expand Down
1 change: 0 additions & 1 deletion src/main/java/com/datastax/cdm/data/CqlConversion.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import java.util.*;
import java.util.stream.Collectors;

import org.apache.hadoop.yarn.webapp.hamlet2.Hamlet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down
71 changes: 71 additions & 0 deletions src/main/java/com/datastax/cdm/data/DataUtility.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
*/
package com.datastax.cdm.data;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.*;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -26,6 +32,7 @@

public class DataUtility {
public static final Logger logger = LoggerFactory.getLogger(CqlConversion.class);
protected static final String SCB_FILE_NAME = "_temp_cdm_scb_do_not_touch.zip";

public static boolean diff(Object obj1, Object obj2) {
if (obj1 == null && obj2 == null) {
Expand Down Expand Up @@ -143,4 +150,68 @@ public static String getMyClassMethodLine(Exception e) {

return "Unknown";
}

public static void deleteGeneratedSCB() {
File file = new File(PKFactory.Side.ORIGIN + SCB_FILE_NAME);
if (file.exists()) {
file.delete();
}
file = new File(PKFactory.Side.TARGET + SCB_FILE_NAME);
if (file.exists()) {
file.delete();
}
}

public static File generateSCB(String host, String port, String trustStorePassword, String trustStorePath,
String keyStorePassword, String keyStorePath, PKFactory.Side side) throws IOException {
FileOutputStream fileOutputStream = new FileOutputStream("config.json");
String scbJson = new StringBuilder("{\"host\": \"").append(host).append("\", \"port\": ").append(port)
.append(", \"keyStoreLocation\": \"./identity.jks\", \"keyStorePassword\": \"").append(keyStorePassword)
.append("\", \"trustStoreLocation\": \"./trustStore.jks\", \"trustStorePassword\": \"")
.append(trustStorePassword).append("\"}").toString();
fileOutputStream.write(scbJson.getBytes());
fileOutputStream.close();
File configFile = new File("config.json");
FilePathAndNewName configFileWithName = new FilePathAndNewName(configFile, "config.json");
FilePathAndNewName keyFileWithName = new FilePathAndNewName(new File(keyStorePath), "identity.jks");
FilePathAndNewName trustFileWithName = new FilePathAndNewName(new File(trustStorePath), "trustStore.jks");
File zipFile = zip(Arrays.asList(configFileWithName, keyFileWithName, trustFileWithName), side + SCB_FILE_NAME);
configFile.delete();

return zipFile;
}

private static File zip(List<FilePathAndNewName> files, String filename) {
File zipfile = new File(filename);
byte[] buf = new byte[1024];
try {
ZipOutputStream out = new ZipOutputStream(new FileOutputStream(zipfile));
for (int i = 0; i < files.size(); i++) {
out.putNextEntry(new ZipEntry(files.get(i).name));
FileInputStream in = new FileInputStream(files.get(i).file.getCanonicalPath());
int len;
while ((len = in.read(buf)) > 0) {
out.write(buf, 0, len);
}
out.closeEntry();
in.close();
}
out.close();

return zipfile;
} catch (IOException ex) {
logger.error("Unable to write out zip file: {}. Got exception: {}", filename, ex.getMessage());
}
return null;
}

static class FilePathAndNewName {
File file;
String name;

public FilePathAndNewName(File file, String name) {
this.file = file;
this.name = name;
}
}
}
2 changes: 2 additions & 0 deletions src/main/java/com/datastax/cdm/job/AbstractJobSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.slf4j.LoggerFactory;

import com.datastax.cdm.cql.EnhancedSession;
import com.datastax.cdm.data.DataUtility;
import com.datastax.cdm.data.PKFactory;
import com.datastax.cdm.feature.Feature;
import com.datastax.cdm.feature.Featureset;
Expand Down Expand Up @@ -114,6 +115,7 @@ public synchronized void initCdmRun(long runId, long prevRunId, Collection<Split
public synchronized void printCounts(boolean isFinal) {
if (isFinal) {
jobCounter.printFinal(runId, trackRunFeature);
DataUtility.deleteGeneratedSCB();
} else {
jobCounter.printProgress();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
import com.datastax.cdm.data.PKFactory;
import com.datastax.cdm.data.Record;
import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.api.core.cql.*;
import com.datastax.oss.driver.api.core.cql.ResultSet;
import com.datastax.oss.driver.api.core.cql.Row;

public class GuardrailCheckJobSession extends AbstractJobSession<SplitPartitions.Partition> {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ public enum PropertyType {
public static final String ORIGIN_TLS_KEYSTORE_PATH = "spark.cdm.connect.origin.tls.keyStore.path";
public static final String ORIGIN_TLS_KEYSTORE_PASSWORD = "spark.cdm.connect.origin.tls.keyStore.password";
public static final String ORIGIN_TLS_ALGORITHMS = "spark.cdm.connect.origin.tls.enabledAlgorithms"; // TLS_RSA_WITH_AES_128_CBC_SHA,TLS_RSA_WITH_AES_256_CBC_SHA
public static final String ORIGIN_TLS_IS_ASTRA = "spark.cdm.connect.origin.tls.isAstra";
static {
types.put(ORIGIN_TLS_ENABLED, PropertyType.BOOLEAN);
defaults.put(ORIGIN_TLS_ENABLED, "false");
Expand All @@ -290,6 +291,8 @@ public enum PropertyType {
types.put(ORIGIN_TLS_KEYSTORE_PASSWORD, PropertyType.STRING);
types.put(ORIGIN_TLS_ALGORITHMS, PropertyType.STRING); // This is a list but it is handled by Spark
defaults.put(ORIGIN_TLS_ALGORITHMS, "TLS_RSA_WITH_AES_128_CBC_SHA,TLS_RSA_WITH_AES_256_CBC_SHA");
types.put(ORIGIN_TLS_IS_ASTRA, PropertyType.BOOLEAN);
defaults.put(ORIGIN_TLS_IS_ASTRA, "false");
}

// ==========================================================================
Expand All @@ -302,6 +305,7 @@ public enum PropertyType {
public static final String TARGET_TLS_KEYSTORE_PATH = "spark.cdm.connect.target.tls.keyStore.path";
public static final String TARGET_TLS_KEYSTORE_PASSWORD = "spark.cdm.connect.target.tls.keyStore.password";
public static final String TARGET_TLS_ALGORITHMS = "spark.cdm.connect.target.tls.enabledAlgorithms"; // TLS_RSA_WITH_AES_128_CBC_SHA,TLS_RSA_WITH_AES_256_CBC_SHA
public static final String TARGET_TLS_IS_ASTRA = "spark.cdm.connect.target.tls.isAstra";
static {
types.put(TARGET_TLS_ENABLED, PropertyType.BOOLEAN);
defaults.put(TARGET_TLS_ENABLED, "false");
Expand All @@ -313,6 +317,8 @@ public enum PropertyType {
types.put(TARGET_TLS_KEYSTORE_PASSWORD, PropertyType.STRING);
types.put(TARGET_TLS_ALGORITHMS, PropertyType.STRING); // This is a list but it is handled by Spark
defaults.put(TARGET_TLS_ALGORITHMS, "TLS_RSA_WITH_AES_128_CBC_SHA,TLS_RSA_WITH_AES_256_CBC_SHA");
types.put(TARGET_TLS_IS_ASTRA, PropertyType.BOOLEAN);
defaults.put(TARGET_TLS_IS_ASTRA, "false");
}

// ==========================================================================
Expand Down
5 changes: 3 additions & 2 deletions src/main/scala/com/datastax/cdm/job/BaseJob.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.slf4j.LoggerFactory
import com.datastax.cdm.data.PKFactory.Side

import java.math.BigInteger
import java.util
Expand Down Expand Up @@ -70,8 +71,8 @@ abstract class BaseJob[T: ClassTag] extends App {

consistencyLevel = propertyHelper.getString(KnownProperties.READ_CL)
val connectionFetcher = new ConnectionFetcher(sContext, propertyHelper)
originConnection = connectionFetcher.getConnection("ORIGIN", consistencyLevel)
targetConnection = connectionFetcher.getConnection("TARGET", consistencyLevel)
originConnection = connectionFetcher.getConnection(Side.ORIGIN, consistencyLevel)
targetConnection = connectionFetcher.getConnection(Side.TARGET, consistencyLevel)

val hasRandomPartitioner: Boolean = {
val partitionerName = originConnection.withSessionDo(_.getMetadata.getTokenMap.get().getPartitionerName)
Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/com/datastax/cdm/job/ConnectionDetails.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,6 @@ case class ConnectionDetails(
trustStoreType: String,
keyStorePath: String,
keyStorePassword: String,
enabledAlgorithms: String
enabledAlgorithms: String,
isAstra: Boolean
)
33 changes: 24 additions & 9 deletions src/main/scala/com/datastax/cdm/job/ConnectionFetcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@
*/
package com.datastax.cdm.job

import com.datastax.cdm.properties.{KnownProperties, PropertyHelper}
import com.datastax.cdm.properties.{KnownProperties, IPropertyHelper}
import com.datastax.spark.connector.cql.CassandraConnector
import org.apache.spark.{SparkConf, SparkContext}
import org.slf4j.{Logger, LoggerFactory}
import com.datastax.cdm.data.DataUtility.generateSCB
import com.datastax.cdm.data.PKFactory.Side

// TODO: CDM-31 - add localDC configuration support
class ConnectionFetcher(sparkContext: SparkContext, propertyHelper: PropertyHelper) {
class ConnectionFetcher(sparkContext: SparkContext, propertyHelper: IPropertyHelper) {
val logger: Logger = LoggerFactory.getLogger(this.getClass.getName)

def getConnectionDetails(side: String): ConnectionDetails = {
if ("ORIGIN".equals(side.toUpperCase)) {
def getConnectionDetails(side: Side): ConnectionDetails = {
if (Side.ORIGIN.equals(side)) {
ConnectionDetails(
propertyHelper.getAsString(KnownProperties.CONNECT_ORIGIN_SCB),
propertyHelper.getAsString(KnownProperties.CONNECT_ORIGIN_HOST),
Expand All @@ -35,10 +37,11 @@ class ConnectionFetcher(sparkContext: SparkContext, propertyHelper: PropertyHelp
propertyHelper.getAsString(KnownProperties.ORIGIN_TLS_ENABLED),
propertyHelper.getAsString(KnownProperties.ORIGIN_TLS_TRUSTSTORE_PATH),
propertyHelper.getAsString(KnownProperties.ORIGIN_TLS_TRUSTSTORE_PASSWORD),
propertyHelper.getString(KnownProperties.ORIGIN_TLS_TRUSTSTORE_TYPE),
propertyHelper.getAsString(KnownProperties.ORIGIN_TLS_TRUSTSTORE_TYPE),
propertyHelper.getAsString(KnownProperties.ORIGIN_TLS_KEYSTORE_PATH),
propertyHelper.getAsString(KnownProperties.ORIGIN_TLS_KEYSTORE_PASSWORD),
propertyHelper.getAsString(KnownProperties.ORIGIN_TLS_ALGORITHMS)
propertyHelper.getAsString(KnownProperties.ORIGIN_TLS_ALGORITHMS),
propertyHelper.getBoolean(KnownProperties.ORIGIN_TLS_IS_ASTRA)
)
}
else {
Expand All @@ -51,15 +54,16 @@ class ConnectionFetcher(sparkContext: SparkContext, propertyHelper: PropertyHelp
propertyHelper.getAsString(KnownProperties.TARGET_TLS_ENABLED),
propertyHelper.getAsString(KnownProperties.TARGET_TLS_TRUSTSTORE_PATH),
propertyHelper.getAsString(KnownProperties.TARGET_TLS_TRUSTSTORE_PASSWORD),
propertyHelper.getString(KnownProperties.TARGET_TLS_TRUSTSTORE_TYPE),
propertyHelper.getAsString(KnownProperties.TARGET_TLS_TRUSTSTORE_TYPE),
propertyHelper.getAsString(KnownProperties.TARGET_TLS_KEYSTORE_PATH),
propertyHelper.getAsString(KnownProperties.TARGET_TLS_KEYSTORE_PASSWORD),
propertyHelper.getAsString(KnownProperties.TARGET_TLS_ALGORITHMS)
propertyHelper.getAsString(KnownProperties.TARGET_TLS_ALGORITHMS),
propertyHelper.getBoolean(KnownProperties.TARGET_TLS_IS_ASTRA)
)
}
}

def getConnection(side: String, consistencyLevel: String): CassandraConnector = {
def getConnection(side: Side, consistencyLevel: String): CassandraConnector = {
val connectionDetails = getConnectionDetails(side)
val config: SparkConf = sparkContext.getConf

Expand All @@ -72,6 +76,17 @@ class ConnectionFetcher(sparkContext: SparkContext, propertyHelper: PropertyHelp
.set("spark.cassandra.auth.password", connectionDetails.password)
.set("spark.cassandra.input.consistency.level", consistencyLevel)
.set("spark.cassandra.connection.config.cloud.path", connectionDetails.scbPath))
} else if (connectionDetails.trustStorePath.nonEmpty && connectionDetails.isAstra) {
logger.info("Connecting to Astra "+side+" (with truststore) using host metadata at "+connectionDetails.host+":"+connectionDetails.port);

val scbFile = generateSCB(connectionDetails.host, connectionDetails.port,
connectionDetails.trustStorePassword, connectionDetails.trustStorePath,
connectionDetails.keyStorePassword, connectionDetails.keyStorePath, side)
return CassandraConnector(config
.set("spark.cassandra.auth.username", connectionDetails.username)
.set("spark.cassandra.auth.password", connectionDetails.password)
.set("spark.cassandra.input.consistency.level", consistencyLevel)
.set("spark.cassandra.connection.config.cloud.path", "file://" + scbFile.getAbsolutePath()))
} else if (connectionDetails.trustStorePath.nonEmpty) {
logger.info("Connecting to "+side+" (with truststore) at "+connectionDetails.host+":"+connectionDetails.port);

Expand Down
3 changes: 3 additions & 0 deletions src/resources/cdm-detailed.properties
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,7 @@ spark.cdm.perfops.ratelimit.target 20000
# .path : Filepath to the Java keystore file
# .password : Password needed to open the keystore
# .enabledAlgorithms : Default is TLS_RSA_WITH_AES_128_CBC_SHA,TLS_RSA_WITH_AES_256_CBC_SHA
# .isAstra : Default is false. Set to true if connecting to DataStax Astra DB without SCB
#-----------------------------------------------------------------------------------------------------------
#spark.cdm.connect.origin.tls.enabled false
#spark.cdm.connect.origin.tls.trustStore.path
Expand All @@ -449,6 +450,7 @@ spark.cdm.perfops.ratelimit.target 20000
#spark.cdm.connect.origin.tls.keyStore.path
#spark.cdm.connect.origin.tls.keyStore.password
#spark.cdm.connect.origin.tls.enabledAlgorithms TLS_RSA_WITH_AES_128_CBC_SHA,TLS_RSA_WITH_AES_256_CBC_SHA
#spark.cdm.connect.origin.tls.isAstra false

#spark.cdm.connect.target.tls.enabled false
#spark.cdm.connect.target.tls.trustStore.path
Expand All @@ -457,3 +459,4 @@ spark.cdm.perfops.ratelimit.target 20000
#spark.cdm.connect.target.tls.keyStore.path
#spark.cdm.connect.target.tls.keyStore.password
#spark.cdm.connect.target.tls.enabledAlgorithms TLS_RSA_WITH_AES_128_CBC_SHA,TLS_RSA_WITH_AES_256_CBC_SHA
#spark.cdm.connect.target.tls.isAstra false
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@
*/
package com.datastax.cdm.cql.codec;

import static org.junit.jupiter.api.Assertions.*;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.util.List;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.Mockito;

import com.datastax.cdm.data.MockitoExtension;
import com.datastax.cdm.properties.PropertyHelper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mock;

import com.datastax.cdm.data.CqlConversion;
import com.datastax.cdm.properties.KnownProperties;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
*/
package com.datastax.cdm.cql.statement;

import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;

import java.util.Arrays;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
*/
package com.datastax.cdm.cql.statement;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.when;

import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -34,7 +35,6 @@
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
import com.datastax.oss.driver.api.core.cql.ResultSet;
import com.datastax.oss.driver.api.core.cql.Row;
import com.datastax.oss.driver.api.core.type.DataTypes;

public class TargetUpsertRunDetailsStatementTest extends CommonMocks {
@Mock
Expand Down
8 changes: 3 additions & 5 deletions src/test/java/com/datastax/cdm/data/CqlConversionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@
*/
package com.datastax.cdm.data;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;

import java.util.Collections;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand Down
Loading

0 comments on commit 0bab9e9

Please sign in to comment.