11
11
import copy
12
12
import datetime
13
13
import functools
14
- import os
15
- import os .path
16
14
import sys
17
- from pathlib import PurePath
15
+ from contextlib import contextmanager
16
+ from pathlib import Path , PurePath
18
17
19
18
import asdf
20
19
import numpy as np
@@ -48,6 +47,26 @@ def wrapper(self, *args, **kwargs):
48
47
return wrapper
49
48
50
49
50
+ @contextmanager
51
+ def _temporary_update_filename (datamodel , filename ):
52
+ """
53
+ Context manager to temporarily update the filename of a datamodel so that it
54
+ can be saved with that new file name without changing the current model's filename
55
+ """
56
+ from roman_datamodels .stnode import Filename
57
+
58
+ if "meta" in datamodel ._instance and "filename" in datamodel ._instance .meta :
59
+ old_filename = datamodel ._instance .meta .filename
60
+ datamodel ._instance .meta .filename = Filename (filename )
61
+
62
+ yield
63
+ datamodel ._instance .meta .filename = old_filename
64
+ return
65
+
66
+ yield
67
+ return
68
+
69
+
51
70
class DataModel (abc .ABC ):
52
71
"""Base class for all top level datamodels"""
53
72
@@ -181,17 +200,9 @@ def clone(target, source, deepcopy=False, memo=None):
181
200
target ._ctx = target
182
201
183
202
def save (self , path , dir_path = None , * args , ** kwargs ):
184
- if callable (path ):
185
- path_head , path_tail = os .path .split (path (self .meta .filename ))
186
- else :
187
- path_head , path_tail = os .path .split (path )
188
- base , ext = os .path .splitext (path_tail )
189
- if isinstance (ext , bytes ):
190
- ext = ext .decode (sys .getfilesystemencoding ())
191
-
192
- if dir_path :
193
- path_head = dir_path
194
- output_path = os .path .join (path_head , path_tail )
203
+ path = Path (path (self .meta .filename ) if callable (path ) else path )
204
+ output_path = Path (dir_path ) / path .name if dir_path else path
205
+ ext = path .suffix .decode (sys .getfilesystemencoding ()) if isinstance (path .suffix , bytes ) else path .suffix
195
206
196
207
# TODO: Support gzip-compressed fits
197
208
if ext == ".asdf" :
@@ -206,10 +217,10 @@ def open_asdf(self, init=None, **kwargs):
206
217
return asdf .open (init , ** kwargs ) if isinstance (init , str ) else asdf .AsdfFile (init , ** kwargs )
207
218
208
219
def to_asdf (self , init , * args , ** kwargs ):
209
- with validate .nuke_validation ():
210
- asdffile = self .open_asdf (** kwargs )
211
- asdffile .tree = {"roman" : self ._instance }
212
- asdffile .write_to (init , * args , ** kwargs )
220
+ with validate .nuke_validation (), _temporary_update_filename ( self , Path ( init ). name ) :
221
+ asdf_file = self .open_asdf (** kwargs )
222
+ asdf_file .tree = {"roman" : self ._instance }
223
+ asdf_file .write_to (init , * args , ** kwargs )
213
224
214
225
def get_primary_array_name (self ):
215
226
"""
0 commit comments