Skip to content

Commit

Permalink
Fix torch scheduler import
Browse files Browse the repository at this point in the history
  • Loading branch information
jrrodri committed Nov 25, 2023
1 parent 3670a46 commit ed40176
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 6 deletions.
2 changes: 1 addition & 1 deletion abraia/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def train_model(model, dataloaders, criterion=None, optimizer=None, scheduler=No
# Observe that only parameters of final layer are being optimized as opposed to before.
optimizer = optimizer or torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)
# Decay LR by a factor of 0.1 every 7 epochs
scheduler = torch.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
scheduler = scheduler or torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

since = time.time()

Expand Down
2 changes: 0 additions & 2 deletions notebooks/torch_training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@
}
],
"source": [
"import os\n",
"import tempfile\n",
"import numpy as np\n",
"from abraia.torch import Dataset, visualize_data, create_model, train_model, visualize_model, save_model, save_classes\n",
"\n",
Expand Down
4 changes: 2 additions & 2 deletions scripts/abraia
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def input_files(src):


@click.group('abraia')
@click.version_option('0.12.2')
@click.version_option('0.12.3')
def cli():
"""Abraia CLI tool"""
pass
Expand All @@ -64,7 +64,7 @@ def configure():
@cli.command()
def info():
"""Show user account information"""
click.echo('abraia, version 0.12.2\n')
click.echo('abraia, version 0.12.3\n')
click.echo('Go to [' + click.style('https://abraia.me/console/', fg='green') + '] to see your account information\n')


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

setup(
name='abraia',
version='0.12.2',
version='0.12.3',
description='Abraia Multiple SDK',
long_description=long_description,
long_description_content_type='text/markdown',
Expand Down

0 comments on commit ed40176

Please sign in to comment.