@@ -56,37 +56,42 @@ def initialize(self):
5656 self .models [name ] = FileData (real_hash , local_data_path )
5757
5858 def get_model (self , model_name ):
59- if model_name in self .models :
60- return self .models [model_name ].file_path
61- elif len ([optional for optional in self .manifest_data ['optional_files' ] if
62- optional ['name' ] == model_name ]) > 0 :
63- self .find_optional_model (model_name )
64- return self .models [model_name ].file_path
59+ if self .available ():
60+ if model_name in self .models :
61+ return self .models [model_name ].file_path
62+ elif len ([optional for optional in self .manifest_data ['optional_files' ] if
63+ optional ['name' ] == model_name ]) > 0 :
64+ self .find_optional_model (model_name )
65+ return self .models [model_name ].file_path
66+ else :
67+ raise Exception ("model name " + model_name + " not found in manifest" )
6568 else :
66- raise Exception ("model name " + model_name + " not found in manifest" )
69+ raise Exception ("unable to get model {}, model_manifest.json not found." . format ( model_name ) )
6770
6871 def find_optional_model (self , file_name ):
69-
70- found_models = [optional for optional in self .manifest_data ['optional_files' ] if
71- optional ['name' ] == file_name ]
72- if len (found_models ) == 0 :
73- raise Exception ("file with name '" + file_name + "' not found in model manifest." )
74- model_info = found_models [0 ]
75- self .models [file_name ] = {}
76- source_uri = model_info ['source_uri' ]
77- fail_on_tamper = model_info .get ("fail_on_tamper" , False )
78- expected_hash = model_info .get ('md5_checksum' , None )
79- with self .client .file (source_uri ).getFile () as f :
80- local_data_path = f .name
81- real_hash = md5_for_file (local_data_path )
82- if self .using_frozen :
83- if real_hash != expected_hash and fail_on_tamper :
84- raise Exception ("Model File Mismatch for " + file_name +
85- "\n expected hash: " + expected_hash + "\n real hash: " + real_hash )
72+ if self .available ():
73+ found_models = [optional for optional in self .manifest_data ['optional_files' ] if
74+ optional ['name' ] == file_name ]
75+ if len (found_models ) == 0 :
76+ raise Exception ("file with name '" + file_name + "' not found in model manifest." )
77+ model_info = found_models [0 ]
78+ self .models [file_name ] = {}
79+ source_uri = model_info ['source_uri' ]
80+ fail_on_tamper = model_info .get ("fail_on_tamper" , False )
81+ expected_hash = model_info .get ('md5_checksum' , None )
82+ with self .client .file (source_uri ).getFile () as f :
83+ local_data_path = f .name
84+ real_hash = md5_for_file (local_data_path )
85+ if self .using_frozen :
86+ if real_hash != expected_hash and fail_on_tamper :
87+ raise Exception ("Model File Mismatch for " + file_name +
88+ "\n expected hash: " + expected_hash + "\n real hash: " + real_hash )
89+ else :
90+ self .models [file_name ] = FileData (real_hash , local_data_path )
8691 else :
8792 self .models [file_name ] = FileData (real_hash , local_data_path )
8893 else :
89- self . models [ file_name ] = FileData ( real_hash , local_data_path )
94+ raise Exception ( "unable to get model {}, model_manifest.json not found." . format ( model_name ) )
9095
9196 def get_manifest (self ):
9297 if os .path .exists (self .manifest_frozen_path ):
0 commit comments