
Commit a0a3a05

Share common strings in the engine database to save some space.
1 parent acf1857 commit a0a3a05

File tree

3 files changed: +76, -30 lines

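In effect, version 9 of the format splits every 2-byte string index into two ranges: indices below SHARED_STRINGS (512) point into a single shared table written once alongside CBaseEntity, while block-local strings are stored per block and offset by 512. A minimal sketch of that split, with invented strings (index_of is illustrative only, not part of the codebase):

# Sketch of the two-level string index this commit introduces (invented data).
SHARED_STRINGS = 512                       # quota for the shared table

shared = {'targetname': 0, 'origin': 1}    # written once, next to CBaseEntity
local = {'beam_color': 0}                  # one table per entity block

def index_of(string: str) -> int:
    # Shared strings keep their raw index; block-local ones are offset.
    return shared[string] if string in shared else SHARED_STRINGS + local[string]

assert index_of('origin') == 1             # shared slot
assert index_of('beam_color') == 512       # 512 + 0, block-local slot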

docs/source/changelog.rst

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ Version (dev)
 * Allow entities to delete multiple keyvalues at once.
 * Fix silent buttons trying to pack invalid `Buttons.snd0` soundscripts.
 * Handle entities being added/removed during iteration of :py:meth:`VMF.search() <srctools.vmf.VMF.search>`.
+* Share common strings in the engine database to save some space.
 
 -------------
 Version 2.4.1

src/srctools/_engine_db.py

Lines changed: 75 additions & 30 deletions
@@ -5,7 +5,7 @@
 """
 from typing import (
     IO, TYPE_CHECKING, AbstractSet, Callable, Collection, Dict, Final, FrozenSet,
-    Iterable, List, Mapping, Optional, Set, Tuple, Union,
+    Iterable, List, Mapping, Optional, Set, Tuple, Union, Counter
 )
 from typing_extensions import TypeAlias
 from enum import IntFlag
@@ -30,23 +30,25 @@
     'serialise', 'unserialise',
 ]
 
-_fmt_8bit: Final = Struct('>B')
-_fmt_16bit: Final = Struct('>H')
-_fmt_32bit: Final = Struct('>I')
-_fmt_double: Final = Struct('>d')
-_fmt_header: Final = Struct('>BI')
-_fmt_ent_header: Final = Struct('>BBBBBB')
-_fmt_block_pos: Final = Struct('>IH')
+_fmt_8bit: Final = Struct('<B')
+_fmt_16bit: Final = Struct('<H')
+_fmt_32bit: Final = Struct('<I')
+_fmt_double: Final = Struct('<d')
+_fmt_header: Final = Struct('<BI')
+_fmt_ent_header: Final = Struct('<BBBBBB')
+_fmt_block_pos: Final = Struct('<IH')
 
 
 # Version number for the format.
-BIN_FORMAT_VERSION: Final = 8
+BIN_FORMAT_VERSION: Final = 9
 TAG_EMPTY: Final[FrozenSet[str]] = frozenset()  # This is a singleton.
 # Soft limit on the number of bytes for each block, needs tuning.
 MAX_BLOCK_SIZE: Final = 2048
 # When writing arrays of strings, it's much more efficient to read the whole thing, decode then
 # split by a character rather than read sizes individually.
 STRING_SEP: Final = '\x1F'  # UNIT SEPARATOR
+# Number of strings to keep in the shared database.
+SHARED_STRINGS: Final = 512
 
 
 class EntFlags(IntFlag):
@@ -190,20 +192,26 @@ class BinStrDict:
 
     Each unique string is assigned a 2-byte index into the list.
     """
-    def __init__(self, database: Iterable[str]) -> None:
+    def __init__(self, database: Iterable[str], base: Optional['BinStrDict']) -> None:
         self._dict: Dict[str, int] = {
             name: ind for ind, name
             in enumerate(database)
         }
-        if len(self._dict) >= (1 << 16):
+        # If no base dict, this is for CBaseEntity, so set it to the real dict,
+        # so __call__() won't add SHARED_STRINGS to the index.
+        self.base_dict: Dict[str, int] = base._dict if base is not None else self._dict
+        if len(self._dict) + len(self.base_dict) >= (1 << 16):
             raise ValueError("Too many items in dictionary!")
 
     def __call__(self, string: str) -> bytes:
         """Get the index for a string.
 
         The result is the two bytes that represent the string.
         """
-        return _fmt_16bit.pack(self._dict[string])
+        if string in self.base_dict:
+            return _fmt_16bit.pack(self.base_dict[string])
+        else:
+            return _fmt_16bit.pack(SHARED_STRINGS + self._dict[string])
 
     def serialise(self, file: IO[bytes]) -> None:
         """Convert this to a stream of bytes."""
@@ -220,15 +228,17 @@ def serialise(self, file: IO[bytes]) -> None:
         file.write(data)
 
     @classmethod
-    def unserialise(cls, file: IO[bytes]) -> Callable[[], str]:
+    def unserialise(cls, file: IO[bytes], base: List[str]) -> Tuple[List[str], Callable[[], str]]:
         """Read the dictionary from a file.
 
-        This returns a function which reads
+        This returns the dict, and a function which reads
         a string from a file at the current point.
         """
         [length] = _fmt_16bit.unpack(file.read(2))
         inv_list = lzma.decompress(file.read(length)).decode('utf8').split(STRING_SEP)
-        return make_lookup(file, inv_list)
+        # This could branch on the index to avoid the concatenation, but this should be
+        # faster, and the dict will only be around temporarily anyway.
+        return inv_list, make_lookup(file, base + inv_list)
 
     @staticmethod
     def read_tags(file: IO[bytes], from_dict: Callable[[], str]) -> FrozenSet[str]:
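On the read side, prepending the shared list makes the offsets line up automatically, so the lookup stays a plain list index. An illustrative check (the write path below asserts the shared table holds exactly SHARED_STRINGS entries, so the padding here stands in for real strings):

# Sketch only: an index written as SHARED_STRINGS + i lands at 512 + i
# once the 512-entry shared list is prepended.
base = ['targetname', 'origin'] + ['pad'] * 510   # shared table, 512 entries
inv_list = ['beam_color']                         # this block's own strings
table = base + inv_list
assert table[1] == 'origin'                       # shared index, unchanged
assert table[512] == 'beam_color'                 # offset block-local index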
@@ -253,9 +263,15 @@ def write_tags(
 
 class EngineDB(_EngineDBProto):
     """Unserialised database, which will be parsed progressively as required."""
-    def __init__(self, ent_map: Dict[str, Union[EntityDef, int]], unparsed: List[Tuple[Iterable[str], bytes]]) -> None:
+    def __init__(
+        self,
+        ent_map: Dict[str, Union[EntityDef, int]],
+        base_strings: List[str],
+        unparsed: List[Tuple[Iterable[str], bytes]],
+    ) -> None:
         self.ent_map = ent_map
         self.unparsed = unparsed
+        self.base_strings = base_strings
         self.fgd: Optional[FGD] = None
 
     def get_classnames(self) -> AbstractSet[str]:
@@ -296,7 +312,7 @@ def _parse_block(self, index: int) -> None:
         apply_bases = []
 
         file = io.BytesIO(data)
-        from_dict = BinStrDict.unserialise(file)
+        _, from_dict = BinStrDict.unserialise(file, self.base_strings)
         for classname in classes:
             self.ent_map[classname.casefold()] = ent = ent_unserialise(file, classname, from_dict)
             if ent.bases:
@@ -307,11 +323,12 @@ def _parse_block(self, index: int) -> None:
         self.unparsed[index] = ((), b'')
         for ent in apply_bases:
             # Apply bases. This should just be for aliases, which are likely also in this block.
+            # Importantly, we've already put those in ent_map, so this won't recurse if they
+            # are in our block.
             ent.bases = [
                 base if isinstance(base, EntityDef) else self.get_ent(base)
                 for base in ent.bases
             ]
-            ent.bases.append(cbase_entity)
 
     def get_fgd(self) -> FGD:
         """Parse all the blocks and make an FGD."""
@@ -580,6 +597,25 @@ def record_strings(string: str) -> bytes:
         ent_to_string[ent] = ent_strings = set()
         ent_serialise(ent, dummy_file, record_strings)
         ent_to_size[ent] = dummy_file.tell()
+
+    assert ent.classname.casefold() == '_cbaseentity_'
+    base_strings = ent_to_string[ent]
+    print(f'{SHARED_STRINGS-len(base_strings)}/{SHARED_STRINGS} shared strings used.')
+
+    # Find common strings, move them to the base set.
+    string_counts = Counter[str]()
+    for strings in ent_to_string.values():
+        string_counts.update(strings)
+    # Shared strings might already be in base, so break early once we hit the quota.
+    # At most we'll need to add SHARED_STRINGS different items.
+    for st, count in string_counts.most_common(SHARED_STRINGS):
+        if len(base_strings) >= SHARED_STRINGS:
+            break
+        base_strings.add(st)
+    for strings in ent_to_string.values():
+        if strings is not base_strings:
+            strings -= base_strings
+
     return ent_to_string, ent_to_size
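The selection step is a plain frequency count. A small worked example of the quota logic, shrunk from 512 to 3 shared slots (invented strings, using collections.Counter rather than the typing alias imported above):

from collections import Counter

SHARED_STRINGS = 3                 # shrunk from 512 for the example
base_strings = {'targetname'}      # CBaseEntity already claimed this slot
counts = Counter({'origin': 9, 'targetname': 7, 'angles': 5, 'radius': 1})

for st, count in counts.most_common(SHARED_STRINGS):
    if len(base_strings) >= SHARED_STRINGS:
        break
    base_strings.add(st)

# 'targetname' was already in the base set, so the quota still admits
# 'angles'; only the rarely-used 'radius' stays block-local.
assert base_strings == {'targetname', 'origin', 'angles'}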
@@ -655,8 +691,15 @@ def add_ent(self, ent: EntityDef) -> None:
     all_blocks.sort(key=lambda block: len(block.ents))
 
     for block in all_blocks:
-        efficency = len(block.stringdb) / sum(map(len, map(ent_to_string.__getitem__, block.ents)))
-        print(f'{block.bytesize} bytes = {len(block.ents)} = {1/efficency:.02%}')
+        if block.stringdb:
+            efficency = format(
+                sum(map(len, map(ent_to_string.__getitem__, block.ents)))
+                / len(block.stringdb),
+                '.02%'
+            )
+        else:
+            efficency = 'All shared'
+        print(f'{block.bytesize} bytes = {len(block.ents)} = {efficency}')
     print(len(all_blocks), 'blocks')
     return [
         (block.ents, block.stringdb)
@@ -676,7 +719,7 @@ def serialise(fgd: FGD, file: IO[bytes]) -> None:
     print('Computing string sizes...')
     # We need the database for CBaseEntity, but not to include it with anything else.
     ent_to_string, ent_to_size = compute_ent_strings(itertools.chain(all_ents, [CBaseEntity]))
-    CBaseEntity_strings = ent_to_string[CBaseEntity]
+    base_strings = ent_to_string[CBaseEntity]
 
     # For every pair of entities (!), compute the number of overlapping ents.
     print('Computing overlaps...')
@@ -708,20 +751,22 @@ def serialise(fgd: FGD, file: IO[bytes]) -> None:
         block_ents.sort(key=operator.attrgetter('classname'))
         for ent in block_ents:
             assert '\x1b' not in ent.classname, ent
-        classnames = lzma.compress(STRING_SEP.join(ent.classname for ent in block_ents).encode('utf8'))
+        # Not worth it to compress these.
+        classnames = STRING_SEP.join(ent.classname for ent in block_ents).encode('utf8')
        file.write(_fmt_16bit.pack(len(classnames)))
         file.write(classnames)
         deferred.defer(('block', id(block_ents)), _fmt_block_pos, write=True)
 
-    # First, write CBaseEntity specially.
-    dictionary = BinStrDict(CBaseEntity_strings)
-    dictionary.serialise(file)
-    ent_serialise(CBaseEntity, file, dictionary)
+    # First, write the base strings and CBaseEntity specially.
+    assert len(base_strings) == SHARED_STRINGS, len(base_strings)
+    base_dict = BinStrDict(base_strings, None)
+    base_dict.serialise(file)
+    ent_serialise(CBaseEntity, file, base_dict)
 
     # Then write each block and then each entity.
     for block_ents, block_stringdb in blocks:
         block_off = file.tell()
-        dictionary = BinStrDict(block_stringdb)
+        dictionary = BinStrDict(block_stringdb, base_dict)
         dictionary.serialise(file)
         for ent in block_ents:
             ent_serialise(ent, file, dictionary)
@@ -748,19 +793,19 @@ def unserialise(file: IO[bytes]) -> _EngineDBProto:
 
     for block_id in range(block_count):
         [cls_size] = _fmt_16bit.unpack(file.read(2))
-        classnames = lzma.decompress(file.read(cls_size)).decode('utf8').split(STRING_SEP)
+        classnames = file.read(cls_size).decode('utf8').split(STRING_SEP)
         block_classnames.append(classnames)
         for name in classnames:
             ent_map[name.casefold()] = block_id
         off, size = _fmt_block_pos.unpack(file.read(_fmt_block_pos.size))
         positions.append((classnames, off, size))
 
     # Read CBaseEntity.
-    from_dict = BinStrDict.unserialise(file)
+    base_strings, from_dict = BinStrDict.unserialise(file, [])
     ent_map['_cbaseentity_'] = ent_unserialise(file, '_CBaseEntity_', from_dict)
 
     for classnames, off, size in positions:
         file.seek(off)
         unparsed.append((classnames, file.read(size)))
 
-    return EngineDB(ent_map, unparsed)
+    return EngineDB(ent_map, base_strings, unparsed)
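Taken together, the write and read paths above imply this on-disk ordering for version 9 (a reconstruction from the diff, not authoritative format documentation):

# Layout sketch, as reconstructed from serialise()/unserialise() above:
#   1. per block: uncompressed classname table, then a deferred (offset, size) pair
#   2. the shared string dict (exactly SHARED_STRINGS entries), then CBaseEntity
#      serialised against it
#   3. per block: the block-local string dict, then its entity definitions,
#      where any index below SHARED_STRINGS refers to the shared table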

src/srctools/fgd.lzma

-133 KB
Binary file not shown.
