From b4654a607a27361f80fe32a2d8f86b43476d87c4 Mon Sep 17 00:00:00 2001 From: yreddy31 Date: Thu, 25 Feb 2021 16:06:10 +0530 Subject: [PATCH] more features! --- setup.py | 2 +- torch_snippets/loader.py | 31 +++++++++++++++++++++++++------ torch_snippets/torch_loader.py | 3 ++- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/setup.py b/setup.py index b3b1a94..5716ee0 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ except ImportError: from distutils.core import setup -VERSION = '0.313' +VERSION = '0.314' setup( name = 'torch_snippets', # How you named your package folder (MyLib) packages = ['torch_snippets'], # Chose the same as "name" diff --git a/torch_snippets/loader.py b/torch_snippets/loader.py index cc13689..89422bd 100755 --- a/torch_snippets/loader.py +++ b/torch_snippets/loader.py @@ -357,10 +357,13 @@ def puttext(ax, string, org, size=15, color=(255,0,0), thickness=2): path_effects.Normal()]) def dumpdill(obj, fpath, silent=False): + start = time.time() os.makedirs(parent(fpath), exist_ok=True) with open(fpath, 'wb') as f: dill.dump(obj, f) - if not silent: logger.info('Dumped object @ {}'.format(fpath)) + if not silent: + fsize = os.path.getsize(fpath) >> 20 + logger.info(f'Dumped object of size `~{fsize} MB` @ "{fpath}" in {time.time()-start:.2f} seconds') def loaddill(fpath): with open(fpath, 'rb') as f: @@ -653,15 +656,31 @@ def to_relative(input, shape): bbs = bbfy(input) return [bb.relative((h,w)) for bb in bbs] +def compute_eps(eps): + if isinstance(eps, tuple): + if len(eps) == 4: + epsx, epsy, epsX, epsY = eps + else: + epsx, epsy = eps + epsx, epsy, epsX, epsY = epsx/2, epsy/2, epsx/2, epsy/2 + else: + epsx, epsy, epsX, epsY = eps/2, eps/2, eps/2, eps/2 + return epsx, epsy, epsX, epsY + def enlarge_bbs(bbs, eps=0.2): - "enlarge all `bbs` by `eps` fraction (or eps*100 percent)" - epsx, epsy = eps if isinstance(eps, tuple) else (eps, eps) + "enlarge all `bbs` by `eps` fraction (i.e., eps*100 percent)" + bbs = bbfy(bbs) + epsx, epsy, epsX, epsY = compute_eps(eps) + bbs = bbfy(bbs) shs = [(bb.h,bb.w) for bb in bbs] - return [BB(x-(w*eps/2), y-(h*eps/2), X+(w*eps/2), Y+(h*eps/2))\ + return [BB(x-(w*epsx), y-(h*epsy), X+(w*epsX), Y+(h*epsY))\ for (x,y,X,Y),(h,w) in zip(bbs, shs)] def shrink_bbs(bbs, eps=0.2): - "shrink all `bbs` by `eps` fraction (or eps*100 percent)" + "shrink all `bbs` by `eps` fraction (i.e., eps*100 percent)" + bbs = bbfy(bbs) + epsx, epsy, epsX, epsY = compute_eps(eps) + bbs = bbfy(bbs) shs = [(bb.h,bb.w) for bb in bbs] - return [BB(x+(w*eps/2), y+(h*eps/2), X-(w*eps/2), Y-(h*eps/2))\ + return [BB(x+(w*epsx), y+(h*epsy), X-(w*epsX), Y-(h*epsY))\ for (x,y,X,Y),(h,w) in zip(bbs, shs)] diff --git a/torch_snippets/torch_loader.py b/torch_snippets/torch_loader.py index 926f148..4f995ad 100644 --- a/torch_snippets/torch_loader.py +++ b/torch_snippets/torch_loader.py @@ -176,6 +176,7 @@ def report_metrics(self, pos, **report): print(f'\r{log}{current_iteration}{info(report, self.precision)}{elapsed}', end=end) try: + import pytorch_lightning as pl from pytorch_lightning.callbacks.progress import ProgressBarBase class LightningReport(ProgressBarBase): def __init__(self, epochs, print_every=None, print_total=None, precision=4, old_report=None): @@ -226,7 +227,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, def __getattr__(self, attr, **kwargs): return getattr(self.report, attr, **kwargs) - __all__ += ['LightningReport'] + __all__ += ['LightningReport', 'pl'] except: logger.warning('Not importing Lightning Report')