Skip to content

Commit

Permalink
tmp
Browse files Browse the repository at this point in the history
  • Loading branch information
bennybp committed Sep 29, 2023
1 parent 5e7779a commit 8dc5342
Showing 1 changed file with 42 additions and 10 deletions.
52 changes: 42 additions & 10 deletions qcportal/qcportal/dataset_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
from __future__ import annotations

from datetime import datetime
from typing import Optional, Dict, Any, List, Iterable, Type, Tuple, Union, Callable, ClassVar
from typing import (
TYPE_CHECKING,
Optional,
Dict,
Any,
List,
Iterable,
Type,
Tuple,
Union,
Callable,
ClassVar,
MutableMapping,
Mapping,
)

import pandas as pd
import pydantic
Expand All @@ -15,6 +29,9 @@
from qcportal.record_models import PriorityEnum, RecordStatusEnum, BaseRecord
from qcportal.utils import make_list, chunk_iterable

if TYPE_CHECKING:
from qcportal.client import PortalClient


class Citation(BaseModel):
"""A literature citation."""
Expand Down Expand Up @@ -85,9 +102,9 @@ class Config:
_entry_names: List[str] = PrivateAttr([])

# To be overridden by the derived class with more specific types
_specifications: Dict[str, Any] = PrivateAttr({})
_entries: Dict[str, Any] = PrivateAttr({})
_record_map: Dict[Tuple[str, str], Any] = PrivateAttr({})
_specifications: MutableMapping[str, Any] = PrivateAttr({})
_entries: MutableMapping[str, Any] = PrivateAttr({})
_record_map: MutableMapping[Tuple[str, str], Any] = PrivateAttr({})

# Values computed outside QCA
_contributed_values: Optional[Dict[str, ContributedValues]] = PrivateAttr(None)
Expand All @@ -110,7 +127,7 @@ class Config:
# Some dataset options
auto_fetch_missing: bool = True # Automatically fetch missing records from the server

def __init__(self, client=None, view_data=None, **kwargs):
def __init__(self, client: Optional[PortalClient] = None, view_data=None, **kwargs):
BaseModel.__init__(self, **kwargs)

# Calls derived class propagate_client
Expand All @@ -119,6 +136,12 @@ def __init__(self, client=None, view_data=None, **kwargs):

assert self._client is client, "Client not set in base dataset class?"

if client and client.cache.enabled:
cache_name = f"dataset_{self.id}"
self._specifications = client.cache.get_cache(cache_name, "specifications", self._specification_type)
self._entries = client.cache.get_cache(cache_name, "entries", self._entry_type)
self._record_map = client.cache.get_cache(cache_name, "record_map", self._record_type, encode_keys=True)

def __init_subclass__(cls):
"""
Register derived classes for later use
Expand Down Expand Up @@ -324,7 +347,7 @@ def set_default_priority(self, new_default_priority: PriorityEnum):
# Specifications
###################################
@property
def specifications(self) -> Dict[str, Any]:
def specifications(self) -> Mapping[str, Any]:
if self.is_view:
return self._view_data.get_specifications(self._specification_type)
else:
Expand All @@ -346,12 +369,15 @@ def fetch_specifications(self) -> None:
self.assert_is_not_view()
self.assert_online()

self._specifications = self._client.make_request(
server_specifications = self._client.make_request(
"get",
f"api/v1/datasets/{self.dataset_type}/{self.id}/specifications",
Dict[str, self._specification_type],
)

self._specifications.clear()
self._specifications.update(server_specifications)

def rename_specification(self, old_name: str, new_name: str):
self.assert_is_not_view()
self.assert_online()
Expand All @@ -362,10 +388,13 @@ def rename_specification(self, old_name: str, new_name: str):
"patch", f"api/v1/datasets/{self.dataset_type}/{self.id}/specifications", None, body=name_map
)

self._specifications = {name_map.get(k, k): v for k, v in self._specifications.items()}
if old_name in self._specifications:
self._specifications[new_name] = self._specifications.pop(old_name)

# Renames the specifications in the record map
self._record_map = {(e, name_map.get(s, s)): r for (e, s), r in self._record_map.items()}
rename_keys = [(e,s) for e,s in self._record_map.keys() if s == old_name]
for e,s in rename_keys:
self._record_map[(e,new_name)] = self._record_map.pop((e,old_name))

def delete_specification(self, name: str, delete_records: bool = False) -> DeleteMetadata:
self.assert_is_not_view()
Expand All @@ -382,7 +411,10 @@ def delete_specification(self, name: str, delete_records: bool = False) -> Delet

# Delete locally-cached stuff
self._specifications.pop(name, None)
self._record_map = {(e, s): r for (e, s), r in self._record_map.items() if s != name}

delete_keys = [(e,s) for e,s in self._record_map.keys() if s == name]
for e,s in delete_keys:
self._record_map.pop((e,s), None)

return ret

Expand Down

0 comments on commit 8dc5342

Please sign in to comment.