Skip to content

Commit

Permalink
Add caching object to client
Browse files Browse the repository at this point in the history
  • Loading branch information
bennybp committed Sep 29, 2023
1 parent a60b143 commit 986cb57
Showing 1 changed file with 9 additions and 69 deletions.
78 changes: 9 additions & 69 deletions qcportal/qcportal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
is_valid_groupname,
)
from .base_models import CommonBulkGetNamesBody, CommonBulkGetBody
from .cache import PortalCache
from .client_base import PortalClientBase
from .dataset_models import (
BaseDataset,
Expand Down Expand Up @@ -116,6 +117,9 @@ def __init__(
password: Optional[str] = None,
verify: bool = True,
show_motd: bool = True,
*,
cache_dir: Optional[str] = None,
cache_max_size: int = 0,
) -> None:
"""
Parameters
Expand All @@ -133,10 +137,14 @@ def __init__(
SSL keys.
show_motd
If a Message-of-the-Day is available, display it
cache_dir
Directory to store an internal cache of records and other data
cache_max_size
Maximum size of the cache directory
"""

PortalClientBase.__init__(self, address, username, password, verify, show_motd)
# self._cache = PortalCache(self, cachedir=cache, max_memcache_size=max_memcache_size)
self.cache = PortalCache(address, cache_dir, cache_max_size)

def __repr__(self) -> str:
"""A short representation of the current PortalClient.
Expand Down Expand Up @@ -165,74 +173,6 @@ def _repr_html_(self) -> str:
# postprocess due to raw spacing above
return "\n".join([substr.strip() for substr in output.split("\n")])

# @property
# def cache(self):
# if self._cache.cachedir is not None:
# return os.path.relpath(self._cache.cachedir)
# else:
# return None

# TODO - reimplement
# def _get_with_cache(self, func, id, missing_ok, entity_type, include=None):
# str_id = make_str(id)
# ids = make_list(str_id)

# # pass through the cache first
# # remove any ids that were found in cache
# # if `include` filters passed, don't use cache, just query DB, as it's often faster
# # for a few fields
# if include is None:
# cached = self._cache.get(ids, entity_type=entity_type)
# else:
# cached = {}

# for i in cached:
# ids.remove(i)

# # if all ids found in cache, no need to go further
# if len(ids) == 0:
# if isinstance(id, list):
# return [cached[i] for i in str_id]
# else:
# return cached[str_id]

# # molecule getting does *not* support "include"
# if include is None:
# payload = {
# "data": {"ids": ids},
# }
# else:
# if "ids" not in include:
# include.append("ids")

# payload = {
# "meta": {"includes": include},
# "data": {"ids": ids},
# }

# results, to_cache = func(payload)

# # we only cache if no field filtering was done
# if include is None:
# self._cache.put(to_cache, entity_type=entity_type)

# # combine cached records with queried results
# results.update(cached)

# # check that we have results for all ids asked for
# missing = set(make_list(str_id)) - set(results.keys())

# if missing and not missing_ok:
# raise KeyError(f"No objects found for `id`: {missing}")

# # order the results by input id list
# if isinstance(id, list):
# ordered = [results.get(i, None) for i in str_id]
# else:
# ordered = results.get(str_id, None)

# return ordered

def get_server_information(self) -> Dict[str, Any]:
"""Request general information about the server
Expand Down

0 comments on commit 986cb57

Please sign in to comment.