@@ -26,7 +26,12 @@ import com.johnsnowlabs.nlp.annotators.ner.{ModelMetrics, NerApproach, Verbose}
2626import com .johnsnowlabs .nlp .annotators .param .EvaluationDLParams
2727import com .johnsnowlabs .nlp .training .NerDLDataLoader
2828import com .johnsnowlabs .nlp .util .io .{OutputHelper , ResourceHelper }
29- import com .johnsnowlabs .nlp .{AnnotatorApproach , AnnotatorType , ParamsAndFeaturesWritable }
29+ import com .johnsnowlabs .nlp .{
30+ Annotation ,
31+ AnnotatorApproach ,
32+ AnnotatorType ,
33+ ParamsAndFeaturesWritable
34+ }
3035import com .johnsnowlabs .storage .HasStorageRef
3136import org .apache .commons .io .IOUtils
3237import org .apache .commons .lang3 .SystemUtils
@@ -457,8 +462,26 @@ class NerDLApproach(override val uid: String)
457462 " Number of batches to prefetch while training using memory optimizer. Has no effect if memory optimizer is disabled." )
458463
459464 def getPrefetchBatches : Int = $(this .prefetchBatches)

  /** Sets number of batches to prefetch while training using memory optimizer. Has no effect if
    * memory optimizer is disabled.
    *
    * @group setParam
    */
  def setPrefetchBatches(value: Int): this.type = set(this.prefetchBatches, value)

  /** Whether to repartition the dataset before training for optimal performance. Has no effect
    * if memory optimizer is disabled (Default: `true`, see `setDefault` below).
    *
    * @group param
    */
  val optimizePartitioning = new BooleanParam(
    this,
    "optimizePartitioning",
    "Whether to repartition the dataset before training for optimal performance. Has no effect if memory optimizer is disabled.")

  /** Whether the dataset is repartitioned before training.
    *
    * @group getParam
    */
  def getOptimizePartitioning: Boolean = $(this.optimizePartitioning)

  /** Sets whether to repartition the dataset before training for optimal performance. Has no
    * effect if memory optimizer is disabled.
    *
    * @group setParam
    */
  def setOptimizePartitioning(value: Boolean): this.type = set(this.optimizePartitioning, value)

462485 setDefault(
463486 minEpochs -> 0 ,
464487 maxEpochs -> 70 ,
@@ -472,7 +495,8 @@ class NerDLApproach(override val uid: String)
472495 enableMemoryOptimizer -> false ,
473496 useBestModel -> false ,
474497 bestModelMetric -> ModelMetrics .loss,
475- prefetchBatches -> 0 )
498+ prefetchBatches -> 0 ,
499+ optimizePartitioning -> true )
476500
  /** Verbosity level for training output, derived from the `verbose` param. */
  override val verboseLevel: Verbose.Level = Verbose($(verbose))
478502
@@ -488,86 +512,112 @@ class NerDLApproach(override val uid: String)
488512 LoadsContrib .loadContribToTensorflow()
489513 }
490514
515+ private def getIteratorFunc (split : Dataset [Row ]) = if (! getEnableMemoryOptimizer) {
516+ // No memory optimizer
517+ NerDLApproach .getIteratorFunc(
518+ split,
519+ inputColumns = getInputCols,
520+ labelColumn = $(labelColumn),
521+ batchSize = $(batchSize),
522+ enableMemoryOptimizer = $(enableMemoryOptimizer))
523+ } else {
524+ logger.info(s " Using memory optimizer with $prefetchBatches prefetch batches. " )
525+ NerDLApproach .getIteratorFunc(
526+ split,
527+ inputColumns = getInputCols,
528+ labelColumn = $(labelColumn),
529+ batchSize = $(batchSize),
530+ prefetchBatches = getPrefetchBatches)
531+ }
532+
533+ /** Extracts graph parameters and returns an optimized dataframe for training.
534+ *
535+ * @param dataset
536+ * input dataset
537+ * @return
538+ * (labels, chars, embeddingsDim, dsLen, optimizedDataset)
539+ */
540+ private def prepareData (dataset : Dataset [Row ])
541+ : (mutable.Set [AnnotatorType ], mutable.Set [Char ], Int , Long , Dataset [Row ]) = {
542+ def optimizePartitioning (ds : Dataset [Row ], dsLen : Long ): Dataset [Row ] = {
543+ if (getEnableMemoryOptimizer && getOptimizePartitioning) {
544+ // Repartition cachedDataset according to heuristic:
545+ // Assume one row contains about 1 MB of data (BertEmbeddings), and spark recommends 100MB to 1GB partitions.
546+ // We'll go for the middle ground of 500MB means that one partition should hold 500 rows
547+ val numPartitions = math.ceil(dsLen / 500.0 ).toInt
548+ logger.info(
549+ s " Repartitioning input cachedDataset to $numPartitions partitions for NerDL training. " )
550+ ds.repartition(numPartitions)
551+ } else ds
552+ }
553+
554+ val cachedDataset : Dataset [Row ] = dataset.cache().toDF()
555+ NerDLApproach .getDataSetParamsFromMetadata(cachedDataset, $(labelColumn)) match {
556+ // metadata contains the length of the entire cachedDataset, so we can avoid a count() action
557+ case Some (
558+ (
559+ labels : mutable.Set [AnnotatorType ],
560+ chars : mutable.Set [Char ],
561+ embeddingsDim : Int ,
562+ dsLen : Long )) =>
563+ // Only repartition if using memory optimizer
564+ val repartitionedDataset = optimizePartitioning(cachedDataset, dsLen)
565+ (labels, chars, embeddingsDim, dsLen, repartitionedDataset)
566+ case None => // Legacy way of getting cachedDataset params
567+ logger.info(" Dataset metadata does not contain graph parameters, extracting from data." )
568+ val docColumn =
569+ Annotation .getColumnByType(dataset, getInputCols, AnnotatorType .DOCUMENT ).name
570+ val dsLen = cachedDataset.selectExpr(s " explode( $docColumn) " ).count()
571+ // Repartition now, so we don't OOM when extracting params
572+ val repartitionedDataset = optimizePartitioning(cachedDataset, dsLen)
573+ val (
574+ labels : mutable.Set [AnnotatorType ],
575+ chars : mutable.Set [Char ],
576+ embeddingsDim : Int ,
577+ _) =
578+ NerDLApproach .getDataSetParams(getIteratorFunc(repartitionedDataset)())
579+ (labels, chars, embeddingsDim, dsLen, repartitionedDataset)
580+ }
581+ }
582+
491583 override def train (
492584 dataset : Dataset [_],
493585 recursivePipeline : Option [PipelineModel ]): NerDLModel = {
494586 require(
495587 $(validationSplit) <= 1f | $(validationSplit) >= 0f ,
496588 " The validationSplit must be between 0f and 1f" )
497589
498- def getIteratorFunc (split : Dataset [Row ]) = if (! getEnableMemoryOptimizer) {
499- // No memory optimizer
500- NerDLApproach .getIteratorFunc(
501- split,
502- inputColumns = getInputCols,
503- labelColumn = $(labelColumn),
504- batchSize = $(batchSize),
505- enableMemoryOptimizer = $(enableMemoryOptimizer))
506- } else {
507- logger.info(s " Using memory optimizer with $prefetchBatches prefetch batches. " )
508- NerDLApproach .getIteratorFunc(
509- split,
510- inputColumns = getInputCols,
511- labelColumn = $(labelColumn),
512- batchSize = $(batchSize),
513- prefetchBatches = getPrefetchBatches)
514- }
515-
516590 val embeddingsRef =
517591 HasStorageRef .getStorageRefFromInput(dataset, $(inputCols), AnnotatorType .WORD_EMBEDDINGS )
518592
593+ val (
594+ labels : mutable.Set [AnnotatorType ],
595+ chars : mutable.Set [Char ],
596+ embeddingsDim : Int ,
597+ dsLen : Long ,
598+ optimizedDataset : Dataset [Row ]) = prepareData(dataset.toDF())
599+ val trainDsLen = math.round(dsLen * (1.0f - $(validationSplit)))
600+ val valDsLen = dsLen - trainDsLen
601+
519602 // Get the data splits
520603 val (trainSplit : Dataset [Row ], validSplit : Dataset [Row ], test : Dataset [Row ]) = {
521- def cacheIfNeeded (ds : Dataset [Row ]): Dataset [Row ] =
522- if (getEnableMemoryOptimizer && getMaxEpochs > 1 ) ds.cache() else ds
523-
524- val train = dataset.toDF()
525604 val (validSplit, trainSplit) =
526- train .randomSplit(Array ($(validationSplit), 1.0f - $(validationSplit))) match {
605+ optimizedDataset .randomSplit(Array ($(validationSplit), 1.0f - $(validationSplit))) match {
527606 case Array (validSplit, trainSplit) => (validSplit, trainSplit)
528607 }
529608
530609 val test =
531- if (! isDefined(testDataset)) train .limit(0 ) // keep the schema only
610+ if (! isDefined(testDataset)) optimizedDataset .limit(0 ) // keep the schema only
532611 else ResourceHelper .readSparkDataFrame($(testDataset))
533612
534- (cacheIfNeeded( trainSplit), cacheIfNeeded( validSplit), cacheIfNeeded( test) )
613+ (trainSplit, validSplit, test)
535614 }
536615
537- // TODO DHA: Better way to do this?
616+ // Get Iterators
538617 val trainIteratorFunc = getIteratorFunc(trainSplit)
539618 val validIteratorFunc = getIteratorFunc(validSplit)
540619 val testIteratorFunc = getIteratorFunc(test)
541620
542- val (
543- labels : mutable.Set [AnnotatorType ],
544- chars : mutable.Set [Char ],
545- embeddingsDim : Int ,
546- trainDsLen : Long ,
547- valDsLen : Long ) = {
548- NerDLApproach .getDataSetParamsFromMetadata(trainSplit, $(labelColumn)) match {
549- // metadata contains the length of the entire dataset
550- case Some (
551- (
552- labels : mutable.Set [AnnotatorType ],
553- chars : mutable.Set [Char ],
554- embeddingsDim : Int ,
555- dsLen : Long )) =>
556- val trainDsLen = math.round(dsLen * (1.0f - $(validationSplit)))
557- val valDsLen = dsLen - trainDsLen
558- (labels, chars, embeddingsDim, trainDsLen.toLong, valDsLen)
559- case None => // Legacy way of getting dataset params
560- val (
561- labels : mutable.Set [AnnotatorType ],
562- chars : mutable.Set [Char ],
563- embeddingsDim : Int ,
564- trainDsLen : Long ) = NerDLApproach .getDataSetParams(trainIteratorFunc())
565- val valDsLen : Long =
566- math.round(trainDsLen / (1 - $(validationSplit)) * $(validationSplit))
567- (labels, chars, embeddingsDim, trainDsLen, valDsLen)
568- }
569- }
570-
571621 val settings = DatasetEncoderParams (
572622 labels.toList,
573623 chars.toList,
@@ -641,9 +691,7 @@ class NerDLApproach(override val uid: String)
641691 if (get(configProtoBytes).isDefined)
642692 model.setConfigProtoBytes($(configProtoBytes))
643693
644- trainSplit.unpersist()
645- validSplit.unpersist()
646- test.unpersist()
694+ optimizedDataset.unpersist()
647695 model
648696 }
649697}
@@ -860,7 +908,7 @@ object NerDLApproach extends DefaultParamsReadable[NerDLApproach] with WithGraph
860908 val chars = metadata.getStringArray(NerDLGraphCheckerModel .charsKey).map(_.head)
861909 val embeddingsDim = metadata.getLong(NerDLGraphCheckerModel .embeddingsDimKey).toInt
862910 val dsLen = metadata.getLong(NerDLGraphCheckerModel .dsLenKey)
863- logger.info(s " NerDLApproach: Found graph params in label column metadata:" +
911+ logger.info(s " Found graph params in label column metadata: " +
864912 s " labels= ${labels.length}, chars= ${chars.length}, embeddingsDim= $embeddingsDim, dsLen= $dsLen" )
865913
866914 Some (
0 commit comments