@@ -5,7 +5,7 @@
 """
 from typing import (
     IO, TYPE_CHECKING, AbstractSet, Callable, Collection, Dict, Final, FrozenSet,
-    Iterable, List, Mapping, Optional, Set, Tuple, Union,
+    Iterable, List, Mapping, Optional, Set, Tuple, Union, Counter
 )
 from typing_extensions import TypeAlias
 from enum import IntFlag
@@ -30,23 +30,25 @@
     'serialise', 'unserialise',
 ]

-_fmt_8bit: Final = Struct('>B')
-_fmt_16bit: Final = Struct('>H')
-_fmt_32bit: Final = Struct('>I')
-_fmt_double: Final = Struct('>d')
-_fmt_header: Final = Struct('>BI')
-_fmt_ent_header: Final = Struct('>BBBBBB')
-_fmt_block_pos: Final = Struct('>IH')
+_fmt_8bit: Final = Struct('<B')
+_fmt_16bit: Final = Struct('<H')
+_fmt_32bit: Final = Struct('<I')
+_fmt_double: Final = Struct('<d')
+_fmt_header: Final = Struct('<BI')
+_fmt_ent_header: Final = Struct('<BBBBBB')
+_fmt_block_pos: Final = Struct('<IH')


 # Version number for the format.
-BIN_FORMAT_VERSION: Final = 8
+BIN_FORMAT_VERSION: Final = 9
 TAG_EMPTY: Final[FrozenSet[str]] = frozenset()  # This is a singleton.
 # Soft limit on the number of bytes for each block, needs tuning.
 MAX_BLOCK_SIZE: Final = 2048
 # When writing arrays of strings, it's much more efficient to read the whole thing, decode then
 # split by a character rather than read sizes individually.
 STRING_SEP: Final = '\x1F'  # UNIT SEPARATOR
+# Number of strings to keep in the shared database.
+SHARED_STRINGS: Final = 512


 class EntFlags(IntFlag):
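
Note on the byte-order switch: every Struct pattern moves from '>' (big-endian) to '<' (little-endian), and BIN_FORMAT_VERSION is bumped so existing readers reject the new files rather than silently mis-decode them. A quick standalone illustration (not part of the patch) of how the same value packs differently:

    from struct import Struct

    # The same 16-bit value yields reversed bytes under each layout,
    # so a version check is the only safe way to tell the two apart.
    assert Struct('>H').pack(0x0102) == b'\x01\x02'  # old, big-endian
    assert Struct('<H').pack(0x0102) == b'\x02\x01'  # new, little-endian
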
@@ -190,20 +192,26 @@ class BinStrDict:

     Each unique string is assigned a 2-byte index into the list.
     """
-    def __init__(self, database: Iterable[str]) -> None:
+    def __init__(self, database: Iterable[str], base: Optional['BinStrDict']) -> None:
         self._dict: Dict[str, int] = {
             name: ind for ind, name
             in enumerate(database)
         }
-        if len(self._dict) >= (1 << 16):
+        # If no base dict, this is for CBaseEntity, so set it to the real dict,
+        # so __call__() won't add SHARED_STRINGS to the index.
+        self.base_dict: Dict[str, int] = base._dict if base is not None else self._dict
+        if len(self._dict) + len(self.base_dict) >= (1 << 16):
             raise ValueError("Too many items in dictionary!")

     def __call__(self, string: str) -> bytes:
         """Get the index for a string.

         The result is the two bytes that represent the string.
         """
-        return _fmt_16bit.pack(self._dict[string])
+        if string in self.base_dict:
+            return _fmt_16bit.pack(self.base_dict[string])
+        else:
+            return _fmt_16bit.pack(SHARED_STRINGS + self._dict[string])

     def serialise(self, file: IO[bytes]) -> None:
         """Convert this to a stream of bytes."""
@@ -220,15 +228,17 @@ def serialise(self, file: IO[bytes]) -> None:
         file.write(data)

     @classmethod
-    def unserialise(cls, file: IO[bytes]) -> Callable[[], str]:
+    def unserialise(cls, file: IO[bytes], base: List[str]) -> Tuple[List[str], Callable[[], str]]:
         """Read the dictionary from a file.

-        This returns a function which reads
+        This returns the dict, and a function which reads
         a string from a file at the current point.
         """
         [length] = _fmt_16bit.unpack(file.read(2))
         inv_list = lzma.decompress(file.read(length)).decode('utf8').split(STRING_SEP)
-        return make_lookup(file, inv_list)
+        # This could branch on the index to avoid the concatenation, but this should be
+        # faster, and the dict will only be around temporarily anyway.
+        return inv_list, make_lookup(file, base + inv_list)

     @staticmethod
     def read_tags(file: IO[bytes], from_dict: Callable[[], str]) -> FrozenSet[str]:
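
The base + inv_list concatenation reproduces exactly the branch sketched above, because the shared list always holds exactly SHARED_STRINGS entries (asserted at write time below), and base is passed as [] when reading the shared dictionary itself, whose indices were written unshifted. A miniature check (tiny quota for illustration):

    SHARED_STRINGS = 4  # the real quota is 512

    base = ['a', 'b', 'c', 'd']   # shared table, exactly SHARED_STRINGS long
    inv_list = ['x', 'y']         # strings local to this block's dictionary

    table = base + inv_list
    for i in range(len(table)):
        # Plain indexing into the concatenation matches the branch.
        assert table[i] == (base[i] if i < SHARED_STRINGS else inv_list[i - SHARED_STRINGS])
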
@@ -253,9 +263,15 @@ def write_tags(

 class EngineDB(_EngineDBProto):
     """Unserialised database, which will be parsed progressively as required."""
-    def __init__(self, ent_map: Dict[str, Union[EntityDef, int]], unparsed: List[Tuple[Iterable[str], bytes]]) -> None:
+    def __init__(
+        self,
+        ent_map: Dict[str, Union[EntityDef, int]],
+        base_strings: List[str],
+        unparsed: List[Tuple[Iterable[str], bytes]],
+    ) -> None:
         self.ent_map = ent_map
         self.unparsed = unparsed
+        self.base_strings = base_strings
         self.fgd: Optional[FGD] = None

     def get_classnames(self) -> AbstractSet[str]:
@@ -296,7 +312,7 @@ def _parse_block(self, index: int) -> None:
         apply_bases = []

         file = io.BytesIO(data)
-        from_dict = BinStrDict.unserialise(file)
+        _, from_dict = BinStrDict.unserialise(file, self.base_strings)
         for classname in classes:
             self.ent_map[classname.casefold()] = ent = ent_unserialise(file, classname, from_dict)
             if ent.bases:
@@ -307,11 +323,12 @@ def _parse_block(self, index: int) -> None:
         self.unparsed[index] = ((), b'')
         for ent in apply_bases:
             # Apply bases. This should just be for aliases, which are likely also in this block.
+            # Importantly, we've already put those in ent_map, so this won't recurse if they
+            # are in our block.
             ent.bases = [
                 base if isinstance(base, EntityDef) else self.get_ent(base)
                 for base in ent.bases
             ]
-            ent.bases.append(cbase_entity)

     def get_fgd(self) -> FGD:
         """Parse all the blocks and make an FGD."""
@@ -580,6 +597,25 @@ def record_strings(string: str) -> bytes:
         ent_to_string[ent] = ent_strings = set()
         ent_serialise(ent, dummy_file, record_strings)
         ent_to_size[ent] = dummy_file.tell()
+
+    assert ent.classname.casefold() == '_cbaseentity_'
+    base_strings = ent_to_string[ent]
+    print(f'{SHARED_STRINGS - len(base_strings)}/{SHARED_STRINGS} shared strings used.')
+
+    # Find common strings, move them to the base set.
+    string_counts = Counter[str]()
+    for strings in ent_to_string.values():
+        string_counts.update(strings)
+    # Shared strings might already be in base, so break early once we hit the quota.
+    # At most we'll need to add SHARED_STRINGS different items.
+    for st, count in string_counts.most_common(SHARED_STRINGS):
+        if len(base_strings) >= SHARED_STRINGS:
+            break
+        base_strings.add(st)
+    for strings in ent_to_string.values():
+        if strings is not base_strings:
+            strings -= base_strings
+
     return ent_to_string, ent_to_size

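
The promotion logic above is a plain frequency count: tally every string across all entities, top up the shared set to its quota with the most common ones, then strip the shared strings from each per-entity set so block dictionaries never duplicate them. A self-contained miniature with made-up strings and a quota of two:

    from collections import Counter

    SHARED = 2  # stand-in for SHARED_STRINGS
    ent_to_string = {
        'ent_a': {'origin', 'angles', 'targetname'},
        'ent_b': {'origin', 'angles', 'model'},
        'ent_c': {'origin', 'skin'},
    }
    base = set()

    counts = Counter()
    for strings in ent_to_string.values():
        counts.update(strings)
    for string, _count in counts.most_common(SHARED):
        if len(base) >= SHARED:
            break
        base.add(string)
    for strings in ent_to_string.values():
        if strings is not base:
            strings -= base

    assert base == {'origin', 'angles'}  # the two most common strings
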
@@ -655,8 +691,15 @@ def add_ent(self, ent: EntityDef) -> None:
     all_blocks.sort(key=lambda block: len(block.ents))

     for block in all_blocks:
-        efficency = len(block.stringdb) / sum(map(len, map(ent_to_string.__getitem__, block.ents)))
-        print(f'{block.bytesize} bytes = {len(block.ents)} = {1 / efficency:.02%}')
+        if block.stringdb:
+            efficency = format(
+                sum(map(len, map(ent_to_string.__getitem__, block.ents)))
+                / len(block.stringdb),
+                '.02%'
+            )
+        else:
+            efficency = 'All shared'
+        print(f'{block.bytesize} bytes = {len(block.ents)} = {efficency}')
     print(len(all_blocks), 'blocks')
     return [
         (block.ents, block.stringdb)
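
Note: the guard matters because a block whose strings all live in the shared table now has an empty local dictionary; the old code would compute efficency = 0 from the empty stringdb and the 1 / efficency in the print would raise ZeroDivisionError. The ratio is also inverted, so it now reads directly as string references per dictionary entry.
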
@@ -676,7 +719,7 @@ def serialise(fgd: FGD, file: IO[bytes]) -> None:
     print('Computing string sizes...')
     # We need the database for CBaseEntity, but not to include it with anything else.
     ent_to_string, ent_to_size = compute_ent_strings(itertools.chain(all_ents, [CBaseEntity]))
-    CBaseEntity_strings = ent_to_string[CBaseEntity]
+    base_strings = ent_to_string[CBaseEntity]

     # For every pair of entities (!), compute the number of overlapping ents.
     print('Computing overlaps...')
@@ -708,20 +751,22 @@ def serialise(fgd: FGD, file: IO[bytes]) -> None:
         block_ents.sort(key=operator.attrgetter('classname'))
         for ent in block_ents:
             assert '\x1b' not in ent.classname, ent
-        classnames = lzma.compress(STRING_SEP.join(ent.classname for ent in block_ents).encode('utf8'))
+        # Not worth it to compress these.
+        classnames = STRING_SEP.join(ent.classname for ent in block_ents).encode('utf8')
         file.write(_fmt_16bit.pack(len(classnames)))
         file.write(classnames)
         deferred.defer(('block', id(block_ents)), _fmt_block_pos, write=True)

-    # First, write CBaseEntity specially.
-    dictionary = BinStrDict(CBaseEntity_strings)
-    dictionary.serialise(file)
-    ent_serialise(CBaseEntity, file, dictionary)
+    # First, write the base strings and CBaseEntity specially.
+    assert len(base_strings) == SHARED_STRINGS, len(base_strings)
+    base_dict = BinStrDict(base_strings, None)
+    base_dict.serialise(file)
+    ent_serialise(CBaseEntity, file, base_dict)

     # Then write each block and then each entity.
     for block_ents, block_stringdb in blocks:
         block_off = file.tell()
-        dictionary = BinStrDict(block_stringdb)
+        dictionary = BinStrDict(block_stringdb, base_dict)
         dictionary.serialise(file)
         for ent in block_ents:
             ent_serialise(ent, file, dictionary)
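
Pieced together from this hunk, the write order after the FGD header appears to be roughly as follows (my reading of the patch, not a documented spec):

    1. For each block: a 16-bit length, the \x1F-joined classnames, and a
       deferred (offset, size) slot patched in later via _fmt_block_pos.
    2. The shared string dictionary, then the CBaseEntity definition
       encoded against it.
    3. For each block: its local string dictionary, then each of its
       entities, encoded against both dictionaries.
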
@@ -748,19 +793,19 @@ def unserialise(file: IO[bytes]) -> _EngineDBProto:

     for block_id in range(block_count):
         [cls_size] = _fmt_16bit.unpack(file.read(2))
-        classnames = lzma.decompress(file.read(cls_size)).decode('utf8').split(STRING_SEP)
+        classnames = file.read(cls_size).decode('utf8').split(STRING_SEP)
         block_classnames.append(classnames)
         for name in classnames:
             ent_map[name.casefold()] = block_id
         off, size = _fmt_block_pos.unpack(file.read(_fmt_block_pos.size))
         positions.append((classnames, off, size))

     # Read CBaseEntity.
-    from_dict = BinStrDict.unserialise(file)
+    base_strings, from_dict = BinStrDict.unserialise(file, [])
     ent_map['_cbaseentity_'] = ent_unserialise(file, '_CBaseEntity_', from_dict)

     for classnames, off, size in positions:
         file.seek(off)
         unparsed.append((classnames, file.read(size)))

-    return EngineDB(ent_map, unparsed)
+    return EngineDB(ent_map, base_strings, unparsed)
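
From the caller's side the base_strings plumbing is invisible; the round trip stays the same (a sketch, assuming fgd is an FGD instance and using the module-level serialise/unserialise above):

    import io

    buf = io.BytesIO()
    serialise(fgd, buf)    # shared dict + CBaseEntity, then the blocks
    buf.seek(0)
    db = unserialise(buf)  # EngineDB, now carrying base_strings
    ent = db.get_ent('_cbaseentity_')  # other blocks parse lazily on demand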