From 8dc5342ebad4d56301cb807d51500323b3e1cd9c Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Sat, 23 Sep 2023 19:45:54 -0400 Subject: [PATCH] tmp --- qcportal/qcportal/dataset_models.py | 52 +++++++++++++++++++++++------ 1 file changed, 42 insertions(+), 10 deletions(-) diff --git a/qcportal/qcportal/dataset_models.py b/qcportal/qcportal/dataset_models.py index 01d338482..8d5f33f84 100644 --- a/qcportal/qcportal/dataset_models.py +++ b/qcportal/qcportal/dataset_models.py @@ -1,7 +1,21 @@ from __future__ import annotations from datetime import datetime -from typing import Optional, Dict, Any, List, Iterable, Type, Tuple, Union, Callable, ClassVar +from typing import ( + TYPE_CHECKING, + Optional, + Dict, + Any, + List, + Iterable, + Type, + Tuple, + Union, + Callable, + ClassVar, + MutableMapping, + Mapping, +) import pandas as pd import pydantic @@ -15,6 +29,9 @@ from qcportal.record_models import PriorityEnum, RecordStatusEnum, BaseRecord from qcportal.utils import make_list, chunk_iterable +if TYPE_CHECKING: + from qcportal.client import PortalClient + class Citation(BaseModel): """A literature citation.""" @@ -85,9 +102,9 @@ class Config: _entry_names: List[str] = PrivateAttr([]) # To be overridden by the derived class with more specific types - _specifications: Dict[str, Any] = PrivateAttr({}) - _entries: Dict[str, Any] = PrivateAttr({}) - _record_map: Dict[Tuple[str, str], Any] = PrivateAttr({}) + _specifications: MutableMapping[str, Any] = PrivateAttr({}) + _entries: MutableMapping[str, Any] = PrivateAttr({}) + _record_map: MutableMapping[Tuple[str, str], Any] = PrivateAttr({}) # Values computed outside QCA _contributed_values: Optional[Dict[str, ContributedValues]] = PrivateAttr(None) @@ -110,7 +127,7 @@ class Config: # Some dataset options auto_fetch_missing: bool = True # Automatically fetch missing records from the server - def __init__(self, client=None, view_data=None, **kwargs): + def __init__(self, client: Optional[PortalClient] = None, view_data=None, **kwargs): BaseModel.__init__(self, **kwargs) # Calls derived class propagate_client @@ -119,6 +136,12 @@ def __init__(self, client=None, view_data=None, **kwargs): assert self._client is client, "Client not set in base dataset class?" + if client and client.cache.enabled: + cache_name = f"dataset_{self.id}" + self._specifications = client.cache.get_cache(cache_name, "specifications", self._specification_type) + self._entries = client.cache.get_cache(cache_name, "entries", self._entry_type) + self._record_map = client.cache.get_cache(cache_name, "record_map", self._record_type, encode_keys=True) + def __init_subclass__(cls): """ Register derived classes for later use @@ -324,7 +347,7 @@ def set_default_priority(self, new_default_priority: PriorityEnum): # Specifications ################################### @property - def specifications(self) -> Dict[str, Any]: + def specifications(self) -> Mapping[str, Any]: if self.is_view: return self._view_data.get_specifications(self._specification_type) else: @@ -346,12 +369,15 @@ def fetch_specifications(self) -> None: self.assert_is_not_view() self.assert_online() - self._specifications = self._client.make_request( + server_specifications = self._client.make_request( "get", f"api/v1/datasets/{self.dataset_type}/{self.id}/specifications", Dict[str, self._specification_type], ) + self._specifications.clear() + self._specifications.update(server_specifications) + def rename_specification(self, old_name: str, new_name: str): self.assert_is_not_view() self.assert_online() @@ -362,10 +388,13 @@ def rename_specification(self, old_name: str, new_name: str): "patch", f"api/v1/datasets/{self.dataset_type}/{self.id}/specifications", None, body=name_map ) - self._specifications = {name_map.get(k, k): v for k, v in self._specifications.items()} + if old_name in self._specifications: + self._specifications[new_name] = self._specifications.pop(old_name) # Renames the specifications in the record map - self._record_map = {(e, name_map.get(s, s)): r for (e, s), r in self._record_map.items()} + rename_keys = [(e,s) for e,s in self._record_map.keys() if s == old_name] + for e,s in rename_keys: + self._record_map[(e,new_name)] = self._record_map.pop((e,old_name)) def delete_specification(self, name: str, delete_records: bool = False) -> DeleteMetadata: self.assert_is_not_view() @@ -382,7 +411,10 @@ def delete_specification(self, name: str, delete_records: bool = False) -> Delet # Delete locally-cached stuff self._specifications.pop(name, None) - self._record_map = {(e, s): r for (e, s), r in self._record_map.items() if s != name} + + delete_keys = [(e,s) for e,s in self._record_map.keys() if s == name] + for e,s in delete_keys: + self._record_map.pop((e,s), None) return ret