1515import os
1616import uuid
1717import shutil
18+ from typed_python .compiler .loaded_module import LoadedModule
19+ from typed_python .compiler .binary_shared_object import BinarySharedObject
1820
19- from typing import Optional , List
20-
21- from typed_python .compiler .binary_shared_object import LoadedBinarySharedObject , BinarySharedObject
22- from typed_python .compiler .directed_graph import DirectedGraph
23- from typed_python .compiler .typed_call_target import TypedCallTarget
2421from typed_python .SerializationContext import SerializationContext
2522from typed_python import Dict , ListOf
2623
@@ -55,173 +52,146 @@ def __init__(self, cacheDir):
5552
5653 ensureDirExists (cacheDir )
5754
58- self .loadedBinarySharedObjects = Dict (str , LoadedBinarySharedObject )()
55+ self .loadedModules = Dict (str , LoadedModule )()
5956 self .nameToModuleHash = Dict (str , str )()
6057
61- self .moduleManifestsLoaded = set ()
58+ self .modulesMarkedValid = set ()
59+ self .modulesMarkedInvalid = set ()
6260
6361 for moduleHash in os .listdir (self .cacheDir ):
6462 if len (moduleHash ) == 40 :
6563 self .loadNameManifestFromStoredModuleByHash (moduleHash )
6664
67- # the set of functions with an associated module in loadedBinarySharedObjects
68- self .targetsLoaded : Dict [str , TypedCallTarget ] = {}
69-
70- # the set of functions with linked and validated globals (i.e. ready to be run).
71- self .targetsValidated = set ()
72-
73- self .function_dependency_graph = DirectedGraph ()
74- # dict from function linkname to list of global names (should be llvm keys in serialisedGlobalDefinitions)
75- self .global_dependencies = Dict (str , ListOf (str ))()
65+ self .targetsLoaded = {}
7666
77- def hasSymbol (self , linkName : str ) -> bool :
78- """NB this will return True even if the linkName is ultimately unretrievable."""
67+ def hasSymbol (self , linkName ):
7968 return linkName in self .nameToModuleHash
8069
81- def getTarget (self , linkName : str ) -> TypedCallTarget :
82- if not self .hasSymbol (linkName ):
83- raise ValueError ( f'symbol not found for linkName { linkName } ' )
70+ def getTarget (self , linkName ) :
71+ assert self .hasSymbol (linkName )
72+
8473 self .loadForSymbol (linkName )
74+
8575 return self .targetsLoaded [linkName ]
8676
87- def dependencies (self , linkName : str ) -> Optional [ List [ str ]] :
88- """Returns all the function names that `linkName` depends on"""
89- return list ( self . function_dependency_graph . outgoing ( linkName ))
77+ def markModuleHashInvalid (self , hashstr ) :
78+ with open ( os . path . join ( self . cacheDir , hashstr , "marked_invalid" ), "w" ):
79+ pass
9080
91- def loadForSymbol (self , linkName : str ) -> None :
92- """Loads the whole module, and any submodules, into LoadedBinarySharedObjects"""
81+ def loadForSymbol (self , linkName ):
9382 moduleHash = self .nameToModuleHash [linkName ]
9483
9584 self .loadModuleByHash (moduleHash )
9685
97- if linkName not in self .targetsValidated :
98- dependantFuncs = self .dependencies (linkName ) + [linkName ]
99- globalsToLink = {} # dict from modulehash to list of globals.
100- for funcName in dependantFuncs :
101- if funcName not in self .targetsValidated :
102- funcModuleHash = self .nameToModuleHash [funcName ]
103- # append to the list of globals to link for a given module. TODO: optimise this, don't double-link.
104- globalsToLink [funcModuleHash ] = globalsToLink .get (funcModuleHash , []) + self .global_dependencies .get (funcName , [])
105-
106- for moduleHash , globs in globalsToLink .items (): # this works because loadModuleByHash loads submodules too.
107- if globs :
108- definitionsToLink = {x : self .loadedBinarySharedObjects [moduleHash ].serializedGlobalVariableDefinitions [x ]
109- for x in globs
110- }
111- self .loadedBinarySharedObjects [moduleHash ].linkGlobalVariables (definitionsToLink )
112- if not self .loadedBinarySharedObjects [moduleHash ].validateGlobalVariables (definitionsToLink ):
113- raise RuntimeError ('failed to validate globals when loading:' , linkName )
114-
115- self .targetsValidated .update (dependantFuncs )
116-
117- def loadModuleByHash (self , moduleHash : str ) -> None :
86+ def loadModuleByHash (self , moduleHash ):
11887 """Load a module by name.
11988
12089 As we load, place all the newly imported typed call targets into
12190 'nameToTypedCallTarget' so that the rest of the system knows what functions
12291 have been uncovered.
12392 """
124- if moduleHash in self .loadedBinarySharedObjects :
125- return
93+ if moduleHash in self .loadedModules :
94+ return True
12695
12796 targetDir = os .path .join (self .cacheDir , moduleHash )
12897
129- # TODO (Will) - store these names as module consts, use one .dat only
130- with open (os .path .join (targetDir , "type_manifest.dat" ), "rb" ) as f :
131- callTargets = SerializationContext ().deserialize (f .read ())
132-
133- with open (os .path .join (targetDir , "globals_manifest.dat" ), "rb" ) as f :
134- serializedGlobalVarDefs = SerializationContext ().deserialize (f .read ())
98+ try :
99+ with open (os .path .join (targetDir , "type_manifest.dat" ), "rb" ) as f :
100+ callTargets = SerializationContext ().deserialize (f .read ())
135101
136- with open (os .path .join (targetDir , "native_type_manifest .dat" ), "rb" ) as f :
137- functionNameToNativeType = SerializationContext ().deserialize (f .read ())
102+ with open (os .path .join (targetDir , "globals_manifest .dat" ), "rb" ) as f :
103+ globalVarDefs = SerializationContext ().deserialize (f .read ())
138104
139- with open (os .path .join (targetDir , "submodules .dat" ), "rb" ) as f :
140- submodules = SerializationContext ().deserialize (f .read (), ListOf ( str ))
105+ with open (os .path .join (targetDir , "native_type_manifest .dat" ), "rb" ) as f :
106+ functionNameToNativeType = SerializationContext ().deserialize (f .read ())
141107
142- with open (os .path .join (targetDir , "function_dependencies.dat" ), "rb" ) as f :
143- dependency_edgelist = SerializationContext ().deserialize (f .read ())
108+ with open (os .path .join (targetDir , "submodules.dat" ), "rb" ) as f :
109+ submodules = SerializationContext ().deserialize (f .read (), ListOf (str ))
110+ except Exception :
111+ self .markModuleHashInvalid (moduleHash )
112+ return False
144113
145- with open (os .path .join (targetDir , "global_dependencies.dat" ), "rb" ) as f :
146- globalDependencies = SerializationContext ().deserialize (f .read ())
114+ if not LoadedModule .validateGlobalVariables (globalVarDefs ):
115+ self .markModuleHashInvalid (moduleHash )
116+ return False
147117
148118 # load the submodules first
149119 for submodule in submodules :
150- self .loadModuleByHash (submodule )
120+ if not self .loadModuleByHash (submodule ):
121+ return False
151122
152123 modulePath = os .path .join (targetDir , "module.so" )
153124
154125 loaded = BinarySharedObject .fromDisk (
155126 modulePath ,
156- serializedGlobalVarDefs ,
157- functionNameToNativeType ,
158- globalDependencies
159-
127+ globalVarDefs ,
128+ functionNameToNativeType
160129 ).loadFromPath (modulePath )
161130
162- self .loadedBinarySharedObjects [moduleHash ] = loaded
131+ self .loadedModules [moduleHash ] = loaded
163132
164133 self .targetsLoaded .update (callTargets )
165134
166- assert not any (key in self .global_dependencies for key in globalDependencies ) # should only happen if there's a hash collision.
167- self .global_dependencies .update (globalDependencies )
168-
169- # update the cache's dependency graph with our new edges.
170- for function_name , dependant_function_name in dependency_edgelist :
171- self .function_dependency_graph .addEdge (source = function_name , dest = dependant_function_name )
135+ return True
172136
173- def addModule (self , binarySharedObject , nameToTypedCallTarget , linkDependencies , dependencyEdgelist ):
137+ def addModule (self , binarySharedObject , nameToTypedCallTarget , linkDependencies ):
174138 """Add new code to the compiler cache.
175139
176140 Args:
177- binarySharedObject: a BinarySharedObject containing the actual assembler
178- we've compiled.
179- nameToTypedCallTarget: a dict from linkname to TypedCallTarget telling us
180- the formal python types for all the objects.
181- linkDependencies: a set of linknames we depend on directly.
182- dependencyEdgelist (list): a list of source, dest pairs giving the set of dependency graph for the
183- module.
141+ binarySharedObject - a BinarySharedObject containing the actual assembler
142+ we've compiled
143+ nameToTypedCallTarget - a dict from linkname to TypedCallTarget telling us
144+ the formal python types for all the objects
145+ linkDependencies - a set of linknames we depend on directly.
184146 """
185147 dependentHashes = set ()
186148
187149 for name in linkDependencies :
188150 dependentHashes .add (self .nameToModuleHash [name ])
189151
190- path , hashToUse = self .writeModuleToDisk (binarySharedObject , nameToTypedCallTarget , dependentHashes , dependencyEdgelist )
152+ path , hashToUse = self .writeModuleToDisk (binarySharedObject , nameToTypedCallTarget , dependentHashes )
191153
192- self .loadedBinarySharedObjects [hashToUse ] = (
154+ self .loadedModules [hashToUse ] = (
193155 binarySharedObject .loadFromPath (os .path .join (path , "module.so" ))
194156 )
195157
196158 for n in binarySharedObject .definedSymbols :
197159 self .nameToModuleHash [n ] = hashToUse
198160
199- # link & validate all globals for the new module
200- self .loadedBinarySharedObjects [hashToUse ].linkGlobalVariables ()
201- if not self .loadedBinarySharedObjects [hashToUse ].validateGlobalVariables (
202- self .loadedBinarySharedObjects [hashToUse ].serializedGlobalVariableDefinitions ):
203- raise RuntimeError ('failed to validate globals in new module:' , hashToUse )
204-
205- def loadNameManifestFromStoredModuleByHash (self , moduleHash ) -> None :
206- if moduleHash in self .moduleManifestsLoaded :
207- return
161+ def loadNameManifestFromStoredModuleByHash (self , moduleHash ):
162+ if moduleHash in self .modulesMarkedValid :
163+ return True
208164
209165 targetDir = os .path .join (self .cacheDir , moduleHash )
210166
167+ # ignore 'marked invalid'
168+ if os .path .exists (os .path .join (targetDir , "marked_invalid" )):
169+ # just bail - don't try to read it now
170+
171+ # for the moment, we don't try to clean up the cache, because
172+ # we can't be sure that some process is not still reading the
173+ # old files.
174+ self .modulesMarkedInvalid .add (moduleHash )
175+ return False
176+
211177 with open (os .path .join (targetDir , "submodules.dat" ), "rb" ) as f :
212178 submodules = SerializationContext ().deserialize (f .read (), ListOf (str ))
213179
214180 for subHash in submodules :
215- self .loadNameManifestFromStoredModuleByHash (subHash )
181+ if not self .loadNameManifestFromStoredModuleByHash (subHash ):
182+ self .markModuleHashInvalid (subHash )
183+ return False
216184
217185 with open (os .path .join (targetDir , "name_manifest.dat" ), "rb" ) as f :
218186 self .nameToModuleHash .update (
219187 SerializationContext ().deserialize (f .read (), Dict (str , str ))
220188 )
221189
222- self .moduleManifestsLoaded .add (moduleHash )
190+ self .modulesMarkedValid .add (moduleHash )
191+
192+ return True
223193
224- def writeModuleToDisk (self , binarySharedObject , nameToTypedCallTarget , submodules , dependencyEdgelist ):
194+ def writeModuleToDisk (self , binarySharedObject , nameToTypedCallTarget , submodules ):
225195 """Write out a disk representation of this module.
226196
227197 This includes writing both the shared object, a manifest of the function names
@@ -274,17 +244,11 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
274244
275245 # write the type manifest
276246 with open (os .path .join (tempTargetDir , "globals_manifest.dat" ), "wb" ) as f :
277- f .write (SerializationContext ().serialize (binarySharedObject .serializedGlobalVariableDefinitions ))
247+ f .write (SerializationContext ().serialize (binarySharedObject .globalVariableDefinitions ))
278248
279249 with open (os .path .join (tempTargetDir , "submodules.dat" ), "wb" ) as f :
280250 f .write (SerializationContext ().serialize (ListOf (str )(submodules ), ListOf (str )))
281251
282- with open (os .path .join (tempTargetDir , "function_dependencies.dat" ), "wb" ) as f :
283- f .write (SerializationContext ().serialize (dependencyEdgelist )) # might need a listof
284-
285- with open (os .path .join (tempTargetDir , "global_dependencies.dat" ), "wb" ) as f :
286- f .write (SerializationContext ().serialize (binarySharedObject .globalDependencies ))
287-
288252 try :
289253 os .rename (tempTargetDir , targetDir )
290254 except IOError :
@@ -300,7 +264,7 @@ def function_pointer_by_name(self, linkName):
300264 if moduleHash is None :
301265 raise Exception ("Can't find a module for " + linkName )
302266
303- if moduleHash not in self .loadedBinarySharedObjects :
267+ if moduleHash not in self .loadedModules :
304268 self .loadForSymbol (linkName )
305269
306- return self .loadedBinarySharedObjects [moduleHash ].functionPointers [linkName ]
270+ return self .loadedModules [moduleHash ].functionPointers [linkName ]
0 commit comments