diff --git a/clients/spark/core/src/main/scala/io/treeverse/clients/ApiClient.scala b/clients/spark/core/src/main/scala/io/treeverse/clients/ApiClient.scala
index 5680ca27594..5139b2a93d2 100644
--- a/clients/spark/core/src/main/scala/io/treeverse/clients/ApiClient.scala
+++ b/clients/spark/core/src/main/scala/io/treeverse/clients/ApiClient.scala
@@ -2,6 +2,11 @@ package io.treeverse.clients
 
 import com.google.common.cache.CacheBuilder
 import io.lakefs.clients.api
+import io.lakefs.clients.api.RetentionApi
+import io.lakefs.clients.api.model.{
+  GarbageCollectionPrepareRequest,
+  GarbageCollectionPrepareResponse
+}
 
 import java.net.URI
 import java.util.concurrent.TimeUnit
@@ -31,6 +36,7 @@ class ApiClient(apiUrl: String, accessKey: String, secretKey: String) {
   private val commitsApi = new api.CommitsApi(client)
   private val metadataApi = new api.MetadataApi(client)
   private val branchesApi = new api.BranchesApi(client)
+  private val retentionApi = new RetentionApi(client)
 
   private val storageNamespaceCache =
     CacheBuilder.newBuilder().expireAfterWrite(2, TimeUnit.MINUTES).build[String, String]()
@@ -50,13 +56,24 @@ class ApiClient(apiUrl: String, accessKey: String, secretKey: String) {
     )
   }
 
+  def prepareGarbageCollectionCommits(
+      repoName: String,
+      previousRunID: String
+  ): GarbageCollectionPrepareResponse = {
+    retentionApi.prepareGarbageCollectionCommits(
+      repoName,
+      new GarbageCollectionPrepareRequest().previousRunId(previousRunID)
+    )
+  }
+
   def getMetaRangeURL(repoName: String, commitID: String): String = {
     val commit = commitsApi.getCommit(repoName, commitID)
     val metaRangeID = commit.getMetaRangeId
-
-    val metaRange = metadataApi.getMetaRange(repoName, metaRangeID)
-    val location = metaRange.getLocation
-    URI.create(getStorageNamespace(repoName) + "/" + location).normalize().toString
+    if (metaRangeID != "") {
+      val metaRange = metadataApi.getMetaRange(repoName, metaRangeID)
+      val location = metaRange.getLocation
+      URI.create(getStorageNamespace(repoName) + "/").resolve(location).normalize().toString
+    } else ""
   }
 
   def getRangeURL(repoName: String, rangeID: String): String = {
diff --git a/clients/spark/core/src/main/scala/io/treeverse/clients/GarbageCollector.scala b/clients/spark/core/src/main/scala/io/treeverse/clients/GarbageCollector.scala
index 3c5fa93a247..cf374368ab6 100644
--- a/clients/spark/core/src/main/scala/io/treeverse/clients/GarbageCollector.scala
+++ b/clients/spark/core/src/main/scala/io/treeverse/clients/GarbageCollector.scala
@@ -7,7 +7,9 @@ import io.treeverse.clients.LakeFSContext.{
   LAKEFS_CONF_API_URL_KEY
 }
 import org.apache.hadoop.conf.Configuration
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
 import org.apache.spark.sql.{SparkSession, _}
 import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration
 import software.amazon.awssdk.core.retry.RetryPolicy
@@ -28,7 +30,6 @@ object GarbageCollector {
       .option("header", value = true)
       .option("inferSchema", value = true)
       .csv(commitDFLocation)
-      .where(col("run_id") === runID)
   }
 
   private def getRangeTuples(
@@ -38,13 +39,16 @@
   ): Set[(String, Array[Byte], Array[Byte])] = {
     val location =
       new ApiClient(conf.apiURL, conf.accessKey, conf.secretKey).getMetaRangeURL(repo, commitID)
-    SSTableReader
-      .forMetaRange(new Configuration(), location)
-      .newIterator()
-      .map(range =>
-        (new String(range.id), range.message.minKey.toByteArray, range.message.maxKey.toByteArray)
-      )
-      .toSet
+    // continue on empty location, empty location is a result of a commit with no metaRangeID (e.g 'Repository created' commit)
+    if (location == "") Set()
+    else
+      SSTableReader
+        .forMetaRange(new Configuration(), location)
+        .newIterator()
+        .map(range =>
+          (new String(range.id), range.message.minKey.toByteArray, range.message.maxKey.toByteArray)
+        )
+        .toSet
   }
 
   def getRangesDFFromCommits(
@@ -67,14 +71,18 @@
       .distinct
   }
 
-  def getRangeAddresses(rangeID: String, conf: APIConfigurations, repo: String): Set[String] = {
+  def getRangeAddresses(
+      rangeID: String,
+      conf: APIConfigurations,
+      repo: String
+  ): Seq[String] = {
     val location =
       new ApiClient(conf.apiURL, conf.accessKey, conf.secretKey).getRangeURL(repo, rangeID)
     SSTableReader
       .forRange(new Configuration(), location)
       .newIterator()
-      .map(a => new String(a.key))
-      .toSet
+      .map(a => a.message.address)
+      .toSeq
   }
 
   def getEntryTuples(
@@ -209,48 +217,84 @@
   ): Dataset[Row] = {
     val commitsDF = getCommitsDF(runID, commitDFLocation, spark)
     val rangesDF = getRangesDFFromCommits(commitsDF, repo, conf)
-    getExpiredEntriesFromRanges(rangesDF, conf, repo)
+    val expired = getExpiredEntriesFromRanges(rangesDF, conf, repo)
+
+    val activeRangesDF = rangesDF.where("!expired")
+    subtractDeduplications(expired, activeRangesDF, conf, repo, spark)
+  }
+
+  private def subtractDeduplications(
+      expired: Dataset[Row],
+      activeRangesDF: Dataset[Row],
+      conf: APIConfigurations,
+      repo: String,
+      spark: SparkSession
+  ): Dataset[Row] = {
+    val activeRangesRDD: RDD[String] =
+      activeRangesDF.select("range_id").rdd.distinct().map(x => x.getString(0))
+    val activeAddresses: RDD[String] = activeRangesRDD
+      .flatMap(range => {
+        getRangeAddresses(range, conf, repo)
+      })
+      .distinct()
+    val activeAddressesRows: RDD[Row] = activeAddresses.map(x => Row(x))
+    val schema = new StructType().add(StructField("address", StringType, true))
+    val activeDF = spark.createDataFrame(activeAddressesRows, schema)
+    // remove active addresses from delete candidates
+    expired.join(
+      activeDF,
+      expired("address") === activeDF("address"),
+      "leftanti"
+    )
   }
 
   def main(args: Array[String]) {
     val spark = SparkSession.builder().getOrCreate()
 
-    if (args.length != 4) {
+    if (args.length != 2) {
       Console.err.println(
-        "Usage: ... s3://storageNamespace/prepared_commits_table s3://storageNamespace/output_destination_table"
+        "Usage: ... <repo_name> <region>"
       )
       System.exit(1)
     }
 
     val repo = args(0)
-    val runID = args(1)
-    val commitDFLocation = args(2)
-    val addressesDFLocation = args(3)
-
+    val region = args(1)
+    val previousRunID =
+      "" //args(2) // TODO(Guys): get previous runID from arguments or from storage
     val hc = spark.sparkContext.hadoopConfiguration
     val apiURL = hc.get(LAKEFS_CONF_API_URL_KEY)
     val accessKey = hc.get(LAKEFS_CONF_API_ACCESS_KEY_KEY)
     val secretKey = hc.get(LAKEFS_CONF_API_SECRET_KEY_KEY)
+    val res = new ApiClient(apiURL, accessKey, secretKey)
+      .prepareGarbageCollectionCommits(repo, previousRunID)
+    val runID = res.getRunId
+
+    val gcCommitsLocation = ApiClient.translateS3(new URI(res.getGcCommitsLocation)).toString
+    val gcAddressesLocation = ApiClient.translateS3(new URI(res.getGcAddressesLocation)).toString
+
     val expiredAddresses = getExpiredAddresses(repo,
                                                runID,
-                                               commitDFLocation,
+                                               gcCommitsLocation,
                                                spark,
                                                APIConfigurations(apiURL, accessKey, secretKey)
     ).withColumn("run_id", lit(runID))
+    spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
     expiredAddresses.write
       .partitionBy("run_id")
-      .mode(SaveMode.Append)
-      .parquet(addressesDFLocation) // TODO(Guys): consider changing to overwrite
+      .mode(SaveMode.Overwrite)
+      .parquet(gcAddressesLocation)
+    S3BulkDeleter.remove(repo, gcAddressesLocation, runID, region, spark)
   }
 }
 
 object S3BulkDeleter {
-  def repartitionBySize(df: DataFrame, maxSize: Int, column: String): DataFrame = {
+  private def repartitionBySize(df: DataFrame, maxSize: Int, column: String): DataFrame = {
     val nRows = df.count()
     val nPartitions = math.max(1, math.floor(nRows / maxSize)).toInt
     df.repartitionByRange(nPartitions, col(column))
   }
 
-  def delObjIteration(
+  private def delObjIteration(
       bucket: String,
       keys: Seq[String],
       s3Client: S3Client,
@@ -297,21 +341,15 @@
     })
   }
 
-  def main(args: Array[String]): Unit = {
-    if (args.length != 5) {
-      Console.err.println(
-        "Usage: ... s3://storageNamespace/prepared_addresses_table s3://storageNamespace/output_destination_table"
-      )
-      System.exit(1)
-    }
+  def remove(
+      repo: String,
+      addressDFLocation: String,
+      runID: String,
+      region: String,
+      spark: SparkSession
+  ) = {
     val MaxBulkSize = 1000
     val awsRetries = 1000
-    val repo = args(0)
-    val runID = args(1)
-    val region = args(2)
-    val addressesDFLocation = args(3)
-    val deletedAddressesDFLocation = args(4)
-    val spark = SparkSession.builder().getOrCreate()
 
     val hc = spark.sparkContext.hadoopConfiguration
     val apiURL = hc.get(LAKEFS_CONF_API_URL_KEY)
@@ -326,16 +364,27 @@ object S3BulkDeleter {
       if (addSuffixSlash.startsWith("/")) addSuffixSlash.substring(1) else addSuffixSlash
 
     val df = spark.read
-      .parquet(addressesDFLocation)
+      .parquet(addressDFLocation)
       .where(col("run_id") === runID)
       .where(col("relative") === true)
     val res =
-      bulkRemove(df, MaxBulkSize, spark, bucket, region, awsRetries, snPrefix).toDF("addresses")
-    res
-      .withColumn("run_id", lit(runID))
-      .write
-      .partitionBy("run_id")
-      .mode(SaveMode.Append)
-      .parquet(deletedAddressesDFLocation)
+      bulkRemove(df, MaxBulkSize, spark, bucket, region, awsRetries, snPrefix)
+        .toDF("addresses")
+        .collect()
+  }
+
+  def main(args: Array[String]): Unit = {
+    if (args.length != 4) {
+      Console.err.println(
+        "Usage: ... s3://storageNamespace/prepared_addresses_table"
+      )
+      System.exit(1)
+    }
+    val repo = args(0)
+    val runID = args(1)
+    val region = args(2)
+    val addressesDFLocation = args(3)
+    val spark = SparkSession.builder().getOrCreate()
+    remove(repo, addressesDFLocation, runID, region, spark)
  }
 }