-
-
Notifications
You must be signed in to change notification settings - Fork 1
/
docutron.py
399 lines (295 loc) · 10.4 KB
/
docutron.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
# -*- coding: utf-8 -*-
"""
# Docutron Toolkit: detection and segmentation analysis for extracting legal data from documents
[![Python](https://img.shields.io/pypi/pyversions/tensorflow.svg)](https://badge.fury.io/py/tensorflow) [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) ![Maintainer](https://img.shields.io/badge/maintainer-@louisbrulenaudet-blue)
Docutron is a tool designed to facilitate the extraction of relevant information from legal documents, enabling professionals to create datasets for fine-tuning large language models (LLMs) for specific legal domains.
# Collecting and installing dependencies
"""
import os
import subprocess
import sys
import distutils.core

# Plain-Python replacement for the notebook's `!git clone` / `!pip install`
# shell escapes, so this file also parses when run outside IPython.
# Skip the clone when the checkout already exists, so re-running is harmless.
if not os.path.isdir("./detectron2"):
    subprocess.run(
        ["git", "clone", "https://github.com/louisbrulenaudet/detectron2"],
        check=True,
    )

# Read detectron2's declared requirements without installing detectron2 itself.
dist = distutils.core.run_setup("./detectron2/setup.py")

# Each requirement is its own argv entry (no shell involved, so no quoting),
# installed into the interpreter that is running this script.
if dist.install_requires:
    subprocess.run(
        [sys.executable, "-m", "pip", "install", *dist.install_requires],
        check=True,
    )

# Make the cloned sources importable as `detectron2`.
sys.path.insert(0, os.path.abspath("./detectron2"))
"""# Importing packages"""
import cv2
import json
import locale
import numpy as np
import os
import random
import requests
import subprocess
import sys
import uuid
import zipfile
from io import BytesIO
from typing import Any
import torch
import distutils.core
import detectron2
from google.colab.patches import cv2_imshow
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultPredictor, DefaultTrainer
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import Visualizer
# Configure detectron2's logger (detectron2.utils.logger.setup_logger) with its
# default arguments before any training or inference output is produced.
setup_logger()
# Force the process-wide preferred encoding to report UTF-8.
# NOTE(review): this looks like the usual Colab locale workaround — confirm it
# is still needed. The replacement must keep the real signature
# getpreferredencoding(do_setlocale=True); a zero-argument lambda makes callers
# such as locale.getpreferredencoding(False) fail with TypeError.
locale.getpreferredencoding = lambda do_setlocale=True: "UTF-8"
def hardware_log() -> None:
    """
    Log hardware and software information related to the environment.

    This function prints the output of ``nvcc --version`` (or a notice when
    the CUDA compiler is not on PATH), the installed PyTorch version, the
    CUDA tag of the PyTorch build, and the Detectron2 version.

    Parameters
    ----------
    None

    Returns
    -------
    None

    Example
    -------
    >>> hardware_log()  # doctest: +SKIP
    [nvcc version banner]
    torch: [major.minor]; cuda: [CUDA tag of the torch build]
    detectron2: [detectron2 version]
    """
    # Plain-Python equivalent of the notebook's `!nvcc --version`, so the file
    # parses outside IPython. Output is captured and printed so it appears in
    # notebook cells; an explicit encoding avoids depending on the locale.
    try:
        result = subprocess.run(
            ["nvcc", "--version"],
            capture_output=True,
            encoding="utf-8",
            errors="replace",
            check=False,
        )
        print(result.stdout + result.stderr, end="")
    except FileNotFoundError:
        # Like the shell escape, a missing CUDA toolkit is reported, not raised.
        print("nvcc: command not found")

    TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
    # Builds such as "2.1.0+cu118" carry the CUDA tag after "+". Without that
    # suffix, split("+")[-1] would return the whole version string, so fall
    # back to torch.version.cuda (None on CPU-only builds).
    _, has_tag, local_tag = torch.__version__.partition("+")
    CUDA_VERSION = local_tag if has_tag else torch.version.cuda
    print(f"torch: {TORCH_VERSION}; cuda: {CUDA_VERSION}")
    print("detectron2:", detectron2.__version__)
    return None


hardware_log()
"""# Datasets registration"""
def download_dataset(url: str, destination: str = "./", timeout: float = 60.0) -> None:
    """
    Download a ZIP archive from a URL and extract it.

    The whole archive is downloaded into memory, the HTTP status is checked,
    and the archive is then extracted into ``destination``.

    Parameters
    ----------
    url : str
        The URL of the ZIP archive to download.
    destination : str, optional
        Directory into which the archive is extracted. Default is "./".
    timeout : float, optional
        Seconds to wait for the server to connect or send data before giving
        up. Default is 60.0.

    Returns
    -------
    None

    Raises
    ------
    requests.HTTPError
        If the server answers with a 4xx/5xx status.
    requests.Timeout
        If the server does not respond within ``timeout`` seconds.
    zipfile.BadZipFile
        If the downloaded payload is not a valid ZIP archive.
    """
    # Without a timeout, requests can wait indefinitely on a stalled server.
    response = requests.get(url, timeout=timeout)
    # Fail with a clear HTTP error instead of handing an error page to zipfile.
    response.raise_for_status()
    with zipfile.ZipFile(BytesIO(response.content), "r") as zip_ref:
        zip_ref.extractall(destination)
    return None


download_dataset(
    url="https://github.com/louisbrulenaudet/docutron/raw/main/datasets.zip"
)
def register_dataset(annotations_path:str, img_dir:str) -> str:
    """
    Register a COCO-format dataset under a freshly generated unique name.

    Parameters
    ----------
    annotations_path : str
        Path to the COCO-format annotations JSON file.
    img_dir : str
        Directory containing the images referenced by the annotations.

    Returns
    -------
    dataset_name : str
        The random UUID4 string under which the dataset was registered
        (no extra metadata is attached: an empty dict is passed).
    """
    name = str(uuid.uuid4())
    register_coco_instances(name, {}, annotations_path, img_dir)
    return name


# Register the training split first, then the validation split; each receives
# its own generated name.
dataset_train, dataset_val = (
    register_dataset(annotations_path=annotations, img_dir=images)
    for annotations, images in (
        ("datasets/train/annotations.json", "datasets/train"),
        ("datasets/val/annotations.json", "datasets/val"),
    )
)
def get_metadata(dataset:str) -> any:
    """
    Look up the metadata object that detectron2's MetadataCatalog holds for a dataset.

    Parameters
    ----------
    dataset : str
        Name under which the dataset was registered.

    Returns
    -------
    metadata : any
        Whatever ``MetadataCatalog.get(dataset)`` returns. Presumably this
        carries class labels and other dataset details; verify against the
        detectron2 documentation.
    """
    return MetadataCatalog.get(dataset)


metadata = get_metadata(dataset=dataset_train)
def visualize_dataset_samples(dataset: str, metadata: Any, num_samples: int = 1, scale: float = 1) -> None:
    """
    Visualize random samples, with their ground-truth annotations, from a registered dataset.

    Samples are drawn with replacement, so the same record may appear twice.

    Parameters
    ----------
    dataset : str
        The name of the registered dataset to visualize.
    metadata : Metadata
        Dataset metadata passed to detectron2's Visualizer.
    num_samples : int, optional
        The number of samples to visualize. Default is 1.
    scale : float, optional
        Display scale passed to the Visualizer. Default is 1.

    Returns
    -------
    None

    Raises
    ------
    FileNotFoundError
        If a sampled record's image cannot be read by OpenCV.
    """
    dataset_dicts = DatasetCatalog.get(dataset)
    for _ in range(num_samples):
        record = random.choice(dataset_dicts)
        img = cv2.imread(record["file_name"])
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {record['file_name']}")
        # Reverse the channel axis (OpenCV loads images as BGR) for the Visualizer,
        # and reverse it again on the way back out to cv2_imshow.
        visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, scale=scale)
        vis = visualizer.draw_dataset_dict(record)
        cv2_imshow(vis.get_image()[:, :, ::-1])
        # NOTE(review): cv2_imshow renders inline, so there is no HighGUI window
        # for waitKey to wait on; kept from the original — confirm it is needed.
        cv2.waitKey(0)
    return None


visualize_dataset_samples(dataset=dataset_train, metadata=metadata, num_samples=1)
"""# Model configuration"""
def model_configuration(
    dataset_train,
    model: str = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml",
    device: str = "cuda",
    num_workers: int = 2,
    batch_size: int = 2,
    learning_rate: float = 0.00025,
    iter: int = 1000,  # shadows the builtin, but renaming would break keyword callers
    bs_per_image: int = 128,
    num_classes: int = 1,
) -> Any:
    """
    Set up the configuration for the object detection model.

    This function configures the model for training by specifying various
    parameters and settings, and creates ``cfg.OUTPUT_DIR`` if it is missing.

    Parameters
    ----------
    dataset_train : str or iterable of str
        The name (or names) of the registered training dataset(s).
    model : str, optional
        The model-zoo configuration file. Default is "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml".
    device : str, optional
        The device used for computation. Default is "cuda". Falls back to
        "cpu" when CUDA is not available.
    num_workers : int, optional
        The number of data loader workers for training. Default is 2.
    batch_size : int, optional
        The number of images per training batch. Default is 2.
    learning_rate : float, optional
        The base learning rate for training. Default is 0.00025.
    iter : int, optional
        The maximum number of training iterations. Default is 1000.
    bs_per_image : int, optional
        The number of ROI-head samples per image. Default is 128.
    num_classes : int, optional
        The number of object classes to detect. Default is 1.

    Returns
    -------
    cfg : CfgNode
        The configuration object for the object detection model.
    """
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(model))

    # Set the device after merging the YAML so the file cannot override it;
    # honor the requested device when CUDA is usable, otherwise use the CPU.
    if device == "cpu" or not torch.cuda.is_available():
        cfg.MODEL.DEVICE = "cpu"
    else:
        cfg.MODEL.DEVICE = device

    # DATASETS.TRAIN is a tuple of names: `(dataset_train)` without a trailing
    # comma is just the string. Iterables of names are still accepted.
    if isinstance(dataset_train, str):
        cfg.DATASETS.TRAIN = (dataset_train,)
    else:
        cfg.DATASETS.TRAIN = tuple(dataset_train)
    cfg.DATASETS.TEST = ()
    cfg.DATALOADER.NUM_WORKERS = num_workers

    # Start from the model-zoo pretrained checkpoint matching the config.
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(model)
    cfg.SOLVER.IMS_PER_BATCH = batch_size
    cfg.SOLVER.BASE_LR = learning_rate
    cfg.SOLVER.MAX_ITER = iter
    cfg.SOLVER.STEPS = []  # no learning-rate decay steps
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = bs_per_image
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes

    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    return cfg


cfg = model_configuration(
    dataset_train=dataset_train,
    model="COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml",
    device="cuda",
    num_workers=2,
    batch_size=2,
    learning_rate=0.00025,
    iter=1000,
    bs_per_image=128,
    num_classes=7
)
"""# Training"""
def training_model(cfg:any) -> None:
    """
    Run detectron2's DefaultTrainer on the given configuration.

    The trainer is built from ``cfg`` and ``resume_or_load(resume=False)`` is
    called before training runs to completion.

    Parameters
    ----------
    cfg : CfgNode
        The configuration object for the object detection model.

    Returns
    -------
    None
    """
    engine = DefaultTrainer(cfg)
    engine.resume_or_load(resume=False)
    engine.train()


training_model(cfg=cfg)
def docutron_config(cfg: Any, threshold: float = 0.7) -> Any:
    """
    Create an object detector from the final weights written to the config's output directory.

    Note that ``cfg`` is modified in place: ``MODEL.WEIGHTS`` is pointed at
    ``model_final.pth`` inside ``cfg.OUTPUT_DIR`` and
    ``MODEL.ROI_HEADS.SCORE_THRESH_TEST`` is set to ``threshold``.

    Parameters
    ----------
    cfg : CfgNode
        The configuration object for the object detection model.
    threshold : float, optional
        Score threshold used to filter out low-scored bounding boxes predicted
        by the ROI heads at inference/test time. Default is 0.7.

    Returns
    -------
    docutron : DefaultPredictor
        An object detection predictor for inference.
    """
    # The annotation was previously `np.float256`, which NumPy does not define;
    # evaluating it raised AttributeError as soon as this `def` executed.
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
    docutron = DefaultPredictor(cfg)
    return docutron
# Build the inference predictor from OUTPUT_DIR/model_final.pth with a 0.7 score threshold.
docutron = docutron_config(cfg=cfg, threshold=0.7)
"""# Testing"""
def get_catalog(dataset:str) -> any:
    """
    Fetch the records registered for a dataset in detectron2's DatasetCatalog.

    Parameters
    ----------
    dataset : str
        The name under which the dataset was registered.

    Returns
    -------
    dataset_dicts : list
        A list of dataset dictionaries, one per record.
    """
    return DatasetCatalog.get(dataset)


catalog = get_catalog(dataset=dataset_val)
def visualize_object_detection(docutron: Any, catalog: list, metadata: Any, num_samples: int = 1, scale: float = 1) -> None:
    """
    Run the predictor on random samples from a dataset and display its detections.

    Samples are drawn with replacement, so the same record may appear twice.

    Parameters
    ----------
    docutron : DefaultPredictor
        An object detection predictor for inference.
    catalog : list
        A list of dataset dictionaries, each with a "file_name" entry.
    metadata : Metadata
        Metadata for the dataset, passed to detectron2's Visualizer.
    num_samples : int, optional
        The number of samples to visualize. Default is 1.
    scale : float, optional
        Display scale passed to the Visualizer. Default is 1.

    Returns
    -------
    None

    Raises
    ------
    FileNotFoundError
        If a sampled record's image cannot be read by OpenCV.
    """
    for _ in range(num_samples):
        record = random.choice(catalog)
        img = cv2.imread(record["file_name"])
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {record['file_name']}")
        outputs = docutron(img)
        # Reverse the channel axis (OpenCV loads images as BGR) for the Visualizer,
        # and reverse it again on the way back out to cv2_imshow.
        v = Visualizer(img[:, :, ::-1], metadata=metadata, scale=scale)
        out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
        cv2_imshow(out.get_image()[:, :, ::-1])
    return None


visualize_object_detection(docutron=docutron, catalog=catalog, metadata=metadata, num_samples=1)