Skip to content

Commit

Permalink
TLDR-462 -- style fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
raxtemur committed Oct 26, 2023
1 parent dd74a65 commit 70f6840
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 9 deletions.
4 changes: 2 additions & 2 deletions dedoc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
# --------------------------------------------JOBLIB SETTINGS-------------------------------------------------------
# number of parallel jobs in some tasks as OCR
n_jobs=1,

# --------------------------------------------GPU SETTINGS-------------------------------------------------------
# set gpu in XGBoost and torch models
# set gpu in XGBoost and torch models
on_gpu=False,

# ---------------------------------------------API SETTINGS---------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,9 @@ def __get_model(self) -> XGBClassifier:
self.__model = pickle.load(f)

if self.config.get("on_gpu", False):
self.__model.set_params(predictor="gpu_predictor", tree_method='auto', n_gpus=1, gpu_id=0)
self.__model.set_params(predictor="gpu_predictor", tree_method="auto", n_gpus=1, gpu_id=0)
self.__model.get_booster().set_param(self.__model.get_params())


return self.__model

def predict(self, lines: List[LineWithMeta]) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _unpickle(self) -> None:
self._feature_extractor = ParagraphFeatureExtractor(**parameters, config=self.config)

if self.config.get("on_gpu", False):
self._classifier.set_params(predictor="gpu_predictor", tree_method='auto', n_gpus=1, gpu_id=0)
self._classifier.set_params(predictor="gpu_predictor", tree_method="auto", n_gpus=1, gpu_id=0)
self._classifier.get_booster().set_param(self._classifier.get_params())

def extract(self, lines_with_links: List[LineWithLocation]) -> List[LineWithLocation]:
Expand Down
6 changes: 2 additions & 4 deletions tests/unit_tests/test_my_gpu_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@
from dedoc.metadata_extractors.concrete_metadata_extractors.base_metadata_extractor import BaseMetadataExtractor
from dedoc.readers.txt_reader.raw_text_reader import RawTextReader
from dedoc.structure_extractors.concrete_structure_extractors.law_structure_excractor import LawStructureExtractor
from tests.api_tests.abstract_api_test import AbstractTestApiDocReader
from dedoc.readers.pdf_reader.pdf_image_reader.columns_orientation_classifier.columns_orientation_classifier import ColumnsOrientationClassifier

from tests.api_tests.abstract_api_test import AbstractTestApiDocReader
from tests.test_utils import get_test_config


@unittest.skip("Should load gpu")
class MyGPUTests(AbstractTestApiDocReader):
config = dict(on_gpu=True)


def _get_abs_path(self, file_name: str) -> str:
return os.path.join(self.data_directory_path, "laws", file_name)
Expand All @@ -37,7 +36,6 @@ def test_law_document_spaces_correctness(self) -> None:
self.assertListEqual([], document.attachments)
self.assertListEqual([], document.tables)


def test_skew_corrector(self) -> None:
checkpoint_path = get_test_config()["resources_path"]
orientation_classifier = ColumnsOrientationClassifier(on_gpu=self.config.get("on_gpu", False), checkpoint_path=checkpoint_path, config=self.config)
Expand Down

0 comments on commit 70f6840

Please sign in to comment.