Skip to content

Commit

Permalink
Merge pull request #3 from smoke-trees/docs/export_keras
Browse files Browse the repository at this point in the history
added docstings to export_keras.py file
  • Loading branch information
Geek-ubaid authored Jun 23, 2020
2 parents e85d9a9 + c11d951 commit cd7610d
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions forest_utils/export_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,33 @@


class ModelFromH5(object):
"""
A class for managing downloads and loading of .h5 models
Parameters
----------
output : str
path to output file for downloading the model (by default it is 'model.h5')
config : str
path to the config file of the model (by default it is 'result.json')
Attributes
----------
base_url : str
This is the bace url for downloading the .h5 model files
url_id
Contains the link to the model extracted from the results.json files
output : str
Relative path to the output file for the model download
Methods
-------
get_complete_url(url)
method to get complete link from the given url
load_model()
download the model .h5 file from the url to output route and returns the loaded keras model
"""

def __init__(self, output = 'model.h5', config = 'result.json'):
super().__init__()

Expand All @@ -17,10 +43,32 @@ def __init__(self, output = 'model.h5', config = 'result.json'):
self.output = output

def get_complete_url(self, url):
"""
method (used internally inside class) to get complete link (including base_url) from the given url
Parameters
----------
url : str
url to split and make complete url from
Returns
-------
link : str
complete url to the model file
"""
split_url = url.split('/')
return self.base_url + split_url[5]

def load_model(self):
"""
method to download model from the url to the output file and load it into keras
Returns
-------
keras model
downloaded model loaded into keras model ready to use!
"""
try:
gdown.download(self.url_id, self.output, quiet = False)
return tf.keras.models.load_model(self.output)
Expand Down

0 comments on commit cd7610d

Please sign in to comment.