Commit 6580ad5

NerDLApproach: Optimize partitioning flag
Allow NerDLApproach to repartition the input dataset, so the driver does not run out of memory when training on large partitions.
1 parent 60227de commit 6580ad5
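
For context, a minimal usage sketch of the new flag (the column names and surrounding setup below are illustrative assumptions, not part of this commit):

// Illustrative sketch: the repartitioning only kicks in when the memory
// optimizer is enabled; optimizePartitioning itself defaults to true.
val ner = new NerDLApproach()
  .setInputCols("sentence", "token", "embeddings") // assumed column names
  .setOutputCol("ner")
  .setLabelColumn("label")
  .setEnableMemoryOptimizer(true) // repartitioning has no effect without this
  .setOptimizePartitioning(true) // set false to keep the input partitioning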

File tree

1 file changed: +109, -61 lines

src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLApproach.scala

Lines changed: 109 additions & 61 deletions
@@ -26,7 +26,12 @@ import com.johnsnowlabs.nlp.annotators.ner.{ModelMetrics, NerApproach, Verbose}
 import com.johnsnowlabs.nlp.annotators.param.EvaluationDLParams
 import com.johnsnowlabs.nlp.training.NerDLDataLoader
 import com.johnsnowlabs.nlp.util.io.{OutputHelper, ResourceHelper}
-import com.johnsnowlabs.nlp.{AnnotatorApproach, AnnotatorType, ParamsAndFeaturesWritable}
+import com.johnsnowlabs.nlp.{
+  Annotation,
+  AnnotatorApproach,
+  AnnotatorType,
+  ParamsAndFeaturesWritable
+}
 import com.johnsnowlabs.storage.HasStorageRef
 import org.apache.commons.io.IOUtils
 import org.apache.commons.lang3.SystemUtils
@@ -457,8 +462,26 @@ class NerDLApproach(override val uid: String)
     "Number of batches to prefetch while training using memory optimizer. Has no effect if memory optimizer is disabled.")

   def getPrefetchBatches: Int = $(this.prefetchBatches)
+
+  /** Sets number of batches to prefetch while training using memory optimizer. Has no effect if
+    * memory optimizer is disabled.
+    * @group setParam
+    */
   def setPrefetchBatches(value: Int): this.type = set(this.prefetchBatches, value)

+  val optimizePartitioning = new BooleanParam(
+    this,
+    "optimizePartitioning",
+    "Whether to repartition the dataset before training for optimal performance. Has no effect if memory optimizer is disabled.")
+
+  def getOptimizePartitioning: Boolean = $(this.optimizePartitioning)
+
+  /** Sets whether to repartition the dataset before training for optimal performance. Has no
+    * effect if memory optimizer is disabled.
+    * @group setParam
+    */
+  def setOptimizePartitioning(value: Boolean): this.type = set(this.optimizePartitioning, value)
+
   setDefault(
     minEpochs -> 0,
     maxEpochs -> 70,
@@ -472,7 +495,8 @@ class NerDLApproach(override val uid: String)
     enableMemoryOptimizer -> false,
     useBestModel -> false,
     bestModelMetric -> ModelMetrics.loss,
-    prefetchBatches -> 0)
+    prefetchBatches -> 0,
+    optimizePartitioning -> true)

   override val verboseLevel: Verbose.Level = Verbose($(verbose))

@@ -488,86 +512,112 @@ class NerDLApproach(override val uid: String)
     LoadsContrib.loadContribToTensorflow()
   }

+  private def getIteratorFunc(split: Dataset[Row]) = if (!getEnableMemoryOptimizer) {
+    // No memory optimizer
+    NerDLApproach.getIteratorFunc(
+      split,
+      inputColumns = getInputCols,
+      labelColumn = $(labelColumn),
+      batchSize = $(batchSize),
+      enableMemoryOptimizer = $(enableMemoryOptimizer))
+  } else {
+    logger.info(s"Using memory optimizer with $prefetchBatches prefetch batches.")
+    NerDLApproach.getIteratorFunc(
+      split,
+      inputColumns = getInputCols,
+      labelColumn = $(labelColumn),
+      batchSize = $(batchSize),
+      prefetchBatches = getPrefetchBatches)
+  }
+
+  /** Extracts graph parameters and returns an optimized dataframe for training.
+    *
+    * @param dataset
+    *   input dataset
+    * @return
+    *   (labels, chars, embeddingsDim, dsLen, optimizedDataset)
+    */
+  private def prepareData(dataset: Dataset[Row])
+      : (mutable.Set[AnnotatorType], mutable.Set[Char], Int, Long, Dataset[Row]) = {
+    def optimizePartitioning(ds: Dataset[Row], dsLen: Long): Dataset[Row] = {
+      if (getEnableMemoryOptimizer && getOptimizePartitioning) {
+        // Repartition cachedDataset according to heuristic:
+        // Assume one row contains about 1 MB of data (BertEmbeddings), and Spark recommends 100MB to 1GB partitions.
+        // We go for the middle ground of 500MB, meaning one partition should hold 500 rows
+        val numPartitions = math.ceil(dsLen / 500.0).toInt
+        logger.info(
+          s"Repartitioning input cachedDataset to $numPartitions partitions for NerDL training.")
+        ds.repartition(numPartitions)
+      } else ds
+    }
+
+    val cachedDataset: Dataset[Row] = dataset.cache().toDF()
+    NerDLApproach.getDataSetParamsFromMetadata(cachedDataset, $(labelColumn)) match {
+      // metadata contains the length of the entire cachedDataset, so we can avoid a count() action
+      case Some(
+            (
+              labels: mutable.Set[AnnotatorType],
+              chars: mutable.Set[Char],
+              embeddingsDim: Int,
+              dsLen: Long)) =>
+        // Only repartition if using memory optimizer
+        val repartitionedDataset = optimizePartitioning(cachedDataset, dsLen)
+        (labels, chars, embeddingsDim, dsLen, repartitionedDataset)
+      case None => // Legacy way of getting cachedDataset params
+        logger.info("Dataset metadata does not contain graph parameters, extracting from data.")
+        val docColumn =
+          Annotation.getColumnByType(dataset, getInputCols, AnnotatorType.DOCUMENT).name
+        val dsLen = cachedDataset.selectExpr(s"explode($docColumn)").count()
+        // Repartition now, so we don't OOM when extracting params
+        val repartitionedDataset = optimizePartitioning(cachedDataset, dsLen)
+        val (
+          labels: mutable.Set[AnnotatorType],
+          chars: mutable.Set[Char],
+          embeddingsDim: Int,
+          _) =
+          NerDLApproach.getDataSetParams(getIteratorFunc(repartitionedDataset)())
+        (labels, chars, embeddingsDim, dsLen, repartitionedDataset)
+    }
+  }
+
   override def train(
       dataset: Dataset[_],
       recursivePipeline: Option[PipelineModel]): NerDLModel = {
     require(
       $(validationSplit) <= 1f | $(validationSplit) >= 0f,
       "The validationSplit must be between 0f and 1f")

-    def getIteratorFunc(split: Dataset[Row]) = if (!getEnableMemoryOptimizer) {
-      // No memory optimizer
-      NerDLApproach.getIteratorFunc(
-        split,
-        inputColumns = getInputCols,
-        labelColumn = $(labelColumn),
-        batchSize = $(batchSize),
-        enableMemoryOptimizer = $(enableMemoryOptimizer))
-    } else {
-      logger.info(s"Using memory optimizer with $prefetchBatches prefetch batches.")
-      NerDLApproach.getIteratorFunc(
-        split,
-        inputColumns = getInputCols,
-        labelColumn = $(labelColumn),
-        batchSize = $(batchSize),
-        prefetchBatches = getPrefetchBatches)
-    }
-
     val embeddingsRef =
       HasStorageRef.getStorageRefFromInput(dataset, $(inputCols), AnnotatorType.WORD_EMBEDDINGS)

+    val (
+      labels: mutable.Set[AnnotatorType],
+      chars: mutable.Set[Char],
+      embeddingsDim: Int,
+      dsLen: Long,
+      optimizedDataset: Dataset[Row]) = prepareData(dataset.toDF())
+    val trainDsLen = math.round(dsLen * (1.0f - $(validationSplit)))
+    val valDsLen = dsLen - trainDsLen
+
     // Get the data splits
     val (trainSplit: Dataset[Row], validSplit: Dataset[Row], test: Dataset[Row]) = {
-      def cacheIfNeeded(ds: Dataset[Row]): Dataset[Row] =
-        if (getEnableMemoryOptimizer && getMaxEpochs > 1) ds.cache() else ds
-
-      val train = dataset.toDF()
       val (validSplit, trainSplit) =
-        train.randomSplit(Array($(validationSplit), 1.0f - $(validationSplit))) match {
+        optimizedDataset.randomSplit(Array($(validationSplit), 1.0f - $(validationSplit))) match {
           case Array(validSplit, trainSplit) => (validSplit, trainSplit)
         }

       val test =
-        if (!isDefined(testDataset)) train.limit(0) // keep the schema only
+        if (!isDefined(testDataset)) optimizedDataset.limit(0) // keep the schema only
         else ResourceHelper.readSparkDataFrame($(testDataset))

-      (cacheIfNeeded(trainSplit), cacheIfNeeded(validSplit), cacheIfNeeded(test))
+      (trainSplit, validSplit, test)
     }

-    // TODO DHA: Better way to do this?
+    // Get Iterators
     val trainIteratorFunc = getIteratorFunc(trainSplit)
     val validIteratorFunc = getIteratorFunc(validSplit)
     val testIteratorFunc = getIteratorFunc(test)

-    val (
-      labels: mutable.Set[AnnotatorType],
-      chars: mutable.Set[Char],
-      embeddingsDim: Int,
-      trainDsLen: Long,
-      valDsLen: Long) = {
-      NerDLApproach.getDataSetParamsFromMetadata(trainSplit, $(labelColumn)) match {
-        // metadata contains the length of the entire dataset
-        case Some(
-              (
-                labels: mutable.Set[AnnotatorType],
-                chars: mutable.Set[Char],
-                embeddingsDim: Int,
-                dsLen: Long)) =>
-          val trainDsLen = math.round(dsLen * (1.0f - $(validationSplit)))
-          val valDsLen = dsLen - trainDsLen
-          (labels, chars, embeddingsDim, trainDsLen.toLong, valDsLen)
-        case None => // Legacy way of getting dataset params
-          val (
-            labels: mutable.Set[AnnotatorType],
-            chars: mutable.Set[Char],
-            embeddingsDim: Int,
-            trainDsLen: Long) = NerDLApproach.getDataSetParams(trainIteratorFunc())
-          val valDsLen: Long =
-            math.round(trainDsLen / (1 - $(validationSplit)) * $(validationSplit))
-          (labels, chars, embeddingsDim, trainDsLen, valDsLen)
-      }
-    }
-
     val settings = DatasetEncoderParams(
       labels.toList,
       chars.toList,
@@ -641,9 +691,7 @@ class NerDLApproach(override val uid: String)
     if (get(configProtoBytes).isDefined)
      model.setConfigProtoBytes($(configProtoBytes))

-    trainSplit.unpersist()
-    validSplit.unpersist()
-    test.unpersist()
+    optimizedDataset.unpersist()
     model
   }
 }
@@ -860,7 +908,7 @@ object NerDLApproach extends DefaultParamsReadable[NerDLApproach] with WithGraph
     val chars = metadata.getStringArray(NerDLGraphCheckerModel.charsKey).map(_.head)
     val embeddingsDim = metadata.getLong(NerDLGraphCheckerModel.embeddingsDimKey).toInt
     val dsLen = metadata.getLong(NerDLGraphCheckerModel.dsLenKey)
-    logger.info(s"NerDLApproach: Found graph params in label column metadata:" +
+    logger.info(s"Found graph params in label column metadata:" +
       s" labels=${labels.length}, chars=${chars.length}, embeddingsDim=$embeddingsDim, dsLen=$dsLen")

     Some(
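
As a quick check of the repartitioning heuristic in prepareData above (a sketch; numPartitionsFor is a hypothetical helper mirroring the commit's math, assuming its ~1 MB/row estimate and ~500 MB target partitions):

// Hypothetical helper mirroring optimizePartitioning's heuristic:
// one partition per ~500 rows (~500 MB at ~1 MB per row).
def numPartitionsFor(dsLen: Long): Int = math.ceil(dsLen / 500.0).toInt

numPartitionsFor(500L) // 1 partition
numPartitionsFor(750L) // 2 partitions
numPartitionsFor(100000L) // 200 partitions

Small datasets collapse into a single partition, while large ones are split so each partition stays near Spark's recommended size instead of exhausting driver memory.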
