Skip to content

Commit 631b350

Browse files
committed
NerDL Optimizations python side
1 parent 6580ad5 commit 631b350

File tree

2 files changed

+92
-1
lines changed

2 files changed

+92
-1
lines changed

python/sparknlp/annotator/ner/ner_dl.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,14 @@ class NerDLApproach(AnnotatorApproach, NerApproach, EvaluationDLParams):
238238
"Whether to check F1 Micro-average or F1 Macro-average as a final metric for the best model.",
239239
TypeConverters.toString)
240240

241+
prefetchBatches = Param(Params._dummy(), "prefetchBatches",
242+
"Number of batches to prefetch while training using memory optimizer. Has no effect if memory optimizer is disabled.",
243+
TypeConverters.toInt)
244+
245+
optimizePartitioning = Param(Params._dummy(), "optimizePartitioning",
246+
"Whether to repartition the dataset before training for optimal performance. Has no effect if memory optimizer is disabled.",
247+
TypeConverters.toBoolean)
248+
241249
def setConfigProtoBytes(self, b):
242250
"""Sets configProto from tensorflow, serialized into byte array.
243251
@@ -377,6 +385,28 @@ def setBestModelMetric(self, value):
377385
"""
378386
return self._set(bestModelMetric=value)
379387

388+
def setPrefetchBatches(self, value):
389+
"""Sets number of batches to prefetch while training using memory optimizer.
390+
Has no effect if memory optimizer is disabled.
391+
392+
Parameters
393+
----------
394+
value : int
395+
Number of batches to prefetch
396+
"""
397+
return self._set(prefetchBatches=value)
398+
399+
def setOptimizePartitioning(self, value):
400+
"""Sets whether to repartition the dataset before training for optimal performance.
401+
Has no effect if memory optimizer is disabled.
402+
403+
Parameters
404+
----------
405+
value: bool
406+
Whether to optimize partitioning
407+
"""
408+
return self._set(optimizePartitioning=value)
409+
380410
def _create_model(self, java_model):
381411
return NerDLModel(java_model=java_model)
382412

@@ -400,7 +430,9 @@ def __init__(self):
400430
enableOutputLogs=False,
401431
enableMemoryOptimizer=False,
402432
useBestModel=False,
403-
bestModelMetric="f1_micro"
433+
bestModelMetric="f1_micro",
434+
prefetchBatches=0,
435+
optimizePartitioning=True
404436
)
405437

406438

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2017-2025 John Snow Labs
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import unittest
15+
16+
import pytest
17+
18+
from sparknlp.annotator import *
19+
from test.util import SparkSessionForTest
20+
21+
22+
@pytest.mark.fast
23+
class NerDLApproachTestSpec(unittest.TestCase):
24+
def setUp(self):
25+
self.spark = SparkSessionForTest.spark
26+
27+
def test_setters(self):
28+
ner_approach = (
29+
NerDLApproach()
30+
.setLr(0.01)
31+
.setPo(0.005)
32+
.setBatchSize(16)
33+
.setDropout(0.01)
34+
.setGraphFolder("graph_folder")
35+
.setConfigProtoBytes([])
36+
.setUseContrib(False)
37+
.setEnableMemoryOptimizer(True)
38+
.setIncludeConfidence(True)
39+
.setIncludeAllConfidenceScores(True)
40+
.setUseBestModel(True)
41+
.setPrefetchBatches(20)
42+
.setOptimizePartitioning(True)
43+
)
44+
45+
# Check param map
46+
param_map = ner_approach.extractParamMap()
47+
self.assertEqual(param_map[ner_approach.lr], 0.01)
48+
self.assertEqual(param_map[ner_approach.po], 0.005)
49+
self.assertEqual(param_map[ner_approach.batchSize], 16)
50+
self.assertEqual(param_map[ner_approach.dropout], 0.01)
51+
self.assertEqual(param_map[ner_approach.graphFolder], "graph_folder")
52+
self.assertEqual(param_map[ner_approach.configProtoBytes], [])
53+
self.assertEqual(param_map[ner_approach.useContrib], False)
54+
self.assertEqual(param_map[ner_approach.enableMemoryOptimizer], True)
55+
self.assertEqual(param_map[ner_approach.includeConfidence], True)
56+
self.assertEqual(param_map[ner_approach.includeAllConfidenceScores], True)
57+
self.assertEqual(param_map[ner_approach.useBestModel], True)
58+
self.assertEqual(param_map[ner_approach.prefetchBatches], 20)
59+
self.assertEqual(param_map[ner_approach.optimizePartitioning], True)

0 commit comments

Comments
 (0)