[MultiVarStore] Fix caching

This commit is contained in:
Behdad Esfahbod 2023-12-15 15:46:55 -07:00
parent 6bcab786a1
commit 5fe9da49f3

View File

@ -32,7 +32,7 @@ class OnlineMultiVarStoreBuilder(object):
self._supports = None self._supports = None
self._varDataIndices = {} self._varDataIndices = {}
self._varDataCaches = {} self._varDataCaches = {}
self._cache = {} self._cache = None
def setModel(self, model): def setModel(self, model):
self.setSupports(model.supports) self.setSupports(model.supports)
@ -43,7 +43,7 @@ class OnlineMultiVarStoreBuilder(object):
self._supports = list(supports) self._supports = list(supports)
if not self._supports[0]: if not self._supports[0]:
del self._supports[0] # Drop base master support del self._supports[0] # Drop base master support
self._cache = {} self._cache = None
self._data = None self._data = None
def finish(self, optimize=True): def finish(self, optimize=True):
@ -51,7 +51,7 @@ class OnlineMultiVarStoreBuilder(object):
self._store.VarDataCount = len(self._store.VarData) self._store.VarDataCount = len(self._store.VarData)
return self._store return self._store
def _add_MultiVarData(self, num_items=1): def _add_MultiVarData(self):
regionMap = self._regionMap regionMap = self._regionMap
regionList = self._regionList regionList = self._regionList
@ -73,7 +73,7 @@ class OnlineMultiVarStoreBuilder(object):
self._outer = varDataIdx self._outer = varDataIdx
self._data = self._store.VarData[varDataIdx] self._data = self._store.VarData[varDataIdx]
self._cache = self._varDataCaches[key] self._cache = self._varDataCaches[key]
if len(self._data.Item) + num_items > 0xFFFF: if len(self._data.Item) == 0xFFFF:
# This is full. Need new one. # This is full. Need new one.
varDataIdx = None varDataIdx = None
@ -95,12 +95,13 @@ class OnlineMultiVarStoreBuilder(object):
deltas = tuple(round(d) for d in deltas) deltas = tuple(round(d) for d in deltas)
deltas_tuple = tuple(tuple(d) for d in deltas) deltas_tuple = tuple(tuple(d) for d in deltas)
if not self._data:
self._add_MultiVarData()
varIdx = self._cache.get(deltas_tuple) varIdx = self._cache.get(deltas_tuple)
if varIdx is not None: if varIdx is not None:
return varIdx return varIdx
if not self._data:
self._add_MultiVarData()
inner = len(self._data.Item) inner = len(self._data.Item)
if inner == 0xFFFF: if inner == 0xFFFF:
# Full array. Start new one. # Full array. Start new one.