Skip to content

Commit

Permalink
pull develop
Browse files Browse the repository at this point in the history
  • Loading branch information
mathleur committed Jan 25, 2024
2 parents c81c7f8 + b95c500 commit 9b14fcb
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 38 deletions.
91 changes: 58 additions & 33 deletions polytope/datacube/backends/fdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(self, config=None, axis_options=None):
self.fake_axes = []
self.unwanted_path = {}
self.nearest_search = {}
self.nearest_search = {}

partial_request = config
# Find values in the level 3 FDB datacube
Expand All @@ -51,7 +52,14 @@ def __init__(self, config=None, axis_options=None):

logging.info("Polytope created axes for: " + str(self._axes.keys()))

def get(self, requests: IndexTree, leaf_path=None):
def get(self, requests: IndexTree):
    """Fetch FDB values for every leaf of ``requests`` using one extract call.

    The request tree is walked first to gather all FDB requests together
    with the bookkeeping needed to decode their output; the FDB is then
    queried once and the returned values are written back onto the nodes.

    Args:
        requests: the index tree whose leaves should receive results.

    Returns:
        None. Results are attached to the tree nodes in place.
    """
    pending_requests = []
    decoding_info = []
    # Both lists are filled in place while the tree is traversed.
    self.get_fdb_requests(requests, pending_requests, decoding_info)
    extracted = self.fdb.extract(pending_requests)
    self.assign_fdb_output_to_nodes(extracted, decoding_info)

def get_fdb_requests(self, requests: IndexTree, fdb_requests=[], fdb_requests_decoding_info=[], leaf_path=None):
if leaf_path is None:
leaf_path = {}

Expand All @@ -60,7 +68,7 @@ def get(self, requests: IndexTree, leaf_path=None):
logging.info("Looking for data for the tree: " + str([leaf.flatten() for leaf in requests.leaves]))

for c in requests.children:
self.get(c)
self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info)
# If request node has no children, we have a leaf so need to assign fdb values to it
else:
key_value_path = {requests.axis.name: requests.value}
Expand All @@ -71,12 +79,23 @@ def get(self, requests: IndexTree, leaf_path=None):
leaf_path.update(key_value_path)
if len(requests.children[0].children[0].children) == 0:
# remap this last key
self.get_2nd_last_values(requests, leaf_path)
# TODO: here, find the fdb_requests and associated nodes to which to add results

(path, range_lengths, current_start_idxs, fdb_node_ranges, lat_length) = self.get_2nd_last_values(
requests, leaf_path
)
(original_indices, sorted_request_ranges) = self.sort_fdb_request_ranges(
range_lengths, current_start_idxs, lat_length
)
fdb_requests.append(tuple((path, sorted_request_ranges)))
fdb_requests_decoding_info.append(
tuple((original_indices, fdb_node_ranges, lat_length, range_lengths, current_start_idxs))
)

# Otherwise remap the path for this key and iterate again over children
else:
for c in requests.children:
self.get(c, leaf_path)
self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info, leaf_path)

def get_2nd_last_values(self, requests, leaf_path=None):
if leaf_path is None:
Expand All @@ -87,6 +106,7 @@ def get_2nd_last_values(self, requests, leaf_path=None):
if len(self.nearest_search) != 0:
first_ax_name = requests.children[0].axis.name
second_ax_name = requests.children[0].children[0].axis.name
# TODO: throw error if first_ax_name or second_ax_name not in self.nearest_search.keys()
nearest_pts = [
[lat_val, lon_val]
for (lat_val, lon_val) in zip(
Expand All @@ -113,7 +133,10 @@ def get_2nd_last_values(self, requests, leaf_path=None):
lat_child.remove_branch()
else:
possible_lons = [latlon[1] for latlon in nearest_latlons if latlon[0] == lat_child.value]
for lon_child in lat_child.children:
lon_children_values = [child.value for child in lat_child.children]
for j in range(len(lon_children_values)):
lon_child_val = lon_children_values[j]
lon_child = [child for child in lat_child.children if child.value == lon_child_val][0]
if lon_child.value not in possible_lons:
lon_child.remove_branch()

Expand All @@ -139,7 +162,10 @@ def get_2nd_last_values(self, requests, leaf_path=None):
(range_lengths[i], current_start_idxs[i], fdb_node_ranges[i]) = self.get_last_layer_before_leaf(
lat_child, leaf_path, range_length, current_start_idx, fdb_range_nodes
)
self.give_fdb_val_to_node(leaf_path, range_lengths, current_start_idxs, fdb_node_ranges, lat_length)
# TODO: do we need to return all of this?
leaf_path_copy = deepcopy(leaf_path)
leaf_path_copy.pop("values")
return (leaf_path_copy, range_lengths, current_start_idxs, fdb_node_ranges, lat_length)

def get_last_layer_before_leaf(self, requests, leaf_path, range_l, current_idx, fdb_range_n):
i = 0
Expand Down Expand Up @@ -171,27 +197,31 @@ def get_last_layer_before_leaf(self, requests, leaf_path, range_l, current_idx,
current_idx[i] = current_start_idx
return (range_l, current_idx, fdb_range_n)

def give_fdb_val_to_node(self, leaf_path, range_lengths, current_start_idx, fdb_range_nodes, lat_length):
    """Look up FDB values for one leaf path and store them on the matching nodes.

    Args:
        leaf_path: path passed through to ``self.find_fdb_values``.
        range_lengths: per-latitude lists of range lengths.
        current_start_idx: per-latitude lists of range start indexes; a
            ``None`` entry marks a range that is ignored.
        fdb_range_nodes: per-latitude lists of node sequences, parallel to
            ``range_lengths``.
        lat_length: number of latitude rows to scan.

    Returns:
        None. Each selected node gets its ``result`` attribute set.
    """
    fetched, sort_order = self.find_fdb_values(leaf_path, range_lengths, current_start_idx, lat_length)

    # Flatten the per-latitude ranges, keeping only those with a start index.
    kept_ranges = [
        (fdb_range_nodes[lat][rng], range_lengths[lat][rng])
        for lat in range(lat_length)
        for rng in range(len(range_lengths[lat]))
        if current_start_idx[lat][rng] is not None
    ]

    # Reorder the kept ranges with the indexes handed back by find_fdb_values,
    # so that position ``pos`` lines up with the pos-th block of fetched values.
    ordered_ranges = [kept_ranges[src] for src in sort_order]

    for pos, (nodes, length) in enumerate(ordered_ranges):
        for offset in range(length):
            node = nodes[offset]
            node.result = fetched[0][0][pos][0][offset]

def find_fdb_values(self, path, range_lengths, current_start_idx, lat_length):
path.pop("values")
fdb_requests = []
def assign_fdb_output_to_nodes(self, output_values, fdb_requests_decoding_info):
    """Write extracted FDB values onto the tree nodes they belong to.

    ``output_values[k]`` is decoded with ``fdb_requests_decoding_info[k]``,
    a tuple ``(original_indices, fdb_node_ranges, lat_length, range_lengths,
    current_start_idxs)``.

    Args:
        output_values: one entry per FDB request; the value for range ``i``
            and offset ``j`` is read from ``entry[0][i][0][j]``.
        fdb_requests_decoding_info: decoding tuples, parallel to
            ``output_values``.

    Returns:
        None. Each selected node gets its ``result`` attribute set.
    """
    for request_pos, request_output_values in enumerate(output_values):
        (
            original_indices,
            fdb_node_ranges,
            lat_length,
            range_lengths,
            current_start_idxs,
        ) = fdb_requests_decoding_info[request_pos]

        # Flatten the per-latitude ranges, keeping only those with a start index.
        kept_ranges = [
            (fdb_node_ranges[lat][rng], range_lengths[lat][rng])
            for lat in range(lat_length)
            for rng in range(len(range_lengths[lat]))
            if current_start_idxs[lat][rng] is not None
        ]

        # Put the kept ranges in the order given by original_indices so that
        # position ``range_pos`` matches the range_pos-th block of output.
        ordered_ranges = [kept_ranges[src] for src in original_indices]

        for range_pos, (nodes, length) in enumerate(ordered_ranges):
            for offset in range(length):
                node = nodes[offset]
                node.result = request_output_values[0][range_pos][0][offset]

def sort_fdb_request_ranges(self, range_lengths, current_start_idx, lat_length):
interm_request_ranges = []
for i in range(lat_length):
for j in range(len(range_lengths[i])):
Expand All @@ -201,12 +231,7 @@ def find_fdb_values(self, path, range_lengths, current_start_idx, lat_length):
request_ranges_with_idx = list(enumerate(interm_request_ranges))
sorted_list = sorted(request_ranges_with_idx, key=lambda x: x[1][0])
original_indices, sorted_request_ranges = zip(*sorted_list)
fdb_requests.append(tuple((path, sorted_request_ranges)))
print("REQUEST TO FDB")
print(fdb_requests)
output_values = self.fdb.extract(fdb_requests)
print(output_values)
return (output_values, original_indices)
return (original_indices, sorted_request_ranges)

def datacube_natural_indexes(self, axis, subarray):
indexes = subarray[axis.name]
Expand Down
12 changes: 7 additions & 5 deletions polytope/datacube/index_tree.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
import logging
from typing import OrderedDict

from sortedcontainers import SortedList

from .datacube_axis import IntDatacubeAxis
from .datacube_axis import IntDatacubeAxis, UnsliceableDatacubeAxis


class DatacubePath(OrderedDict):
Expand Down Expand Up @@ -85,7 +86,7 @@ def __eq__(self, other):
if other.value == self.value:
return True
else:
if isinstance(self.value, str):
if isinstance(self.axis, UnsliceableDatacubeAxis):
return False
else:
if other.value - 2 * other.axis.tol <= self.value <= other.value + 2 * other.axis.tol:
Expand All @@ -94,7 +95,6 @@ def __eq__(self, other):
return True
else:
return False
# return (self.axis.name, self.value) == (other.axis.name, other.value)

def __lt__(self, other):
    """Order nodes by axis name first, then by value (plain tuple comparison)."""
    own_key = (self.axis.name, self.value)
    other_key = (other.axis.name, other.value)
    return own_key < other_key
Expand Down Expand Up @@ -169,10 +169,12 @@ def intersect(self, other):

def pprint(self, level=0):
    """Log this subtree at DEBUG level, one node per line, indented by depth.

    Output goes through ``logging.debug`` only; nothing is written to
    stdout, so each line is emitted exactly once.

    Args:
        level: current depth, used as the number of leading tab characters.

    Returns:
        None.
    """
    if self.axis.name == "root":
        # Blank separator record before the tree, emitted for the root only.
        logging.debug("\n")
    logging.debug("\t" * level + "\u21b3" + str(self))
    for child in self.children:
        child.pprint(level + 1)
    if not self.children:
        # Leaves also show the value attached to them, one level deeper.
        logging.debug("\t" * (level + 1) + "\u21b3" + str(self.result))

def remove_branch(self):
if not self.is_root():
Expand Down
41 changes: 41 additions & 0 deletions tests/test_point_nearest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def setup_method(self, method):
"date": {"merge": {"with": "time", "linkers": ["T", "00"]}},
"step": {"type_change": "int"},
"number": {"type_change": "int"},
"longitude": {"cyclic": [0, 360]},
}
self.config = {"class": "od", "expver": "0001", "levtype": "sfc", "stream": "oper"}
self.fdbdatacube = FDBDatacube(self.config, axis_options=self.options)
Expand Down Expand Up @@ -76,3 +77,43 @@ def test_fdb_datacube_true_point_2(self):
result = self.API.retrieve(request)
result.pprint()
assert len(result.leaves) == 3

@pytest.mark.fdb
def test_fdb_datacube_true_point_3(self):
    """Nearest-point request at longitude -0.01 must return one leaf at longitude 0."""
    selections = [
        Select("step", [21]),
        Select("levtype", ["sfc"]),
        Select("date", [pd.Timestamp("20231102T000000")]),
        Select("domain", ["g"]),
        Select("expver", ["0001"]),
        Select("param", ["167"]),
        Select("class", ["od"]),
        Select("stream", ["oper"]),
        Select("type", ["fc"]),
    ]
    # The requested longitude lies just below 0.
    nearest_point = Point(["latitude", "longitude"], [[0.035149384216, -0.01]], method="nearest")
    result = self.API.retrieve(Request(*selections, nearest_point))
    result.pprint()
    assert len(result.leaves) == 1
    only_leaf = result.leaves[0]
    assert only_leaf.value == 0
    assert only_leaf.axis.name == "longitude"

@pytest.mark.fdb
def test_fdb_datacube_true_point_4(self):
    """Nearest-point request at longitude 359.97 must return one leaf at 359.929906542056."""
    fixed_axes = (
        ("step", [21]),
        ("levtype", ["sfc"]),
        ("date", [pd.Timestamp("20231102T000000")]),
        ("domain", ["g"]),
        ("expver", ["0001"]),
        ("param", ["167"]),
        ("class", ["od"]),
        ("stream", ["oper"]),
        ("type", ["fc"]),
    )
    shapes = [Select(axis_name, values) for axis_name, values in fixed_axes]
    # The requested longitude lies just below 360.
    shapes.append(Point(["latitude", "longitude"], [[0.035149384216, 359.97]], method="nearest"))
    request = Request(*shapes)
    result = self.API.retrieve(request)
    result.pprint()
    assert len(result.leaves) == 1
    only_leaf = result.leaves[0]
    assert only_leaf.value == 359.929906542056
    assert only_leaf.axis.name == "longitude"
1 change: 1 addition & 0 deletions tests/test_slice_date_range_fdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def test_fdb_datacube(self):
Select("class", ["od"]),
Select("stream", ["oper"]),
Select("type", ["an"]),
Select("number", [1]),
Box(["latitude", "longitude"], [0, 0], [0.2, 0.2]),
)
result = self.API.retrieve(request)
Expand Down

0 comments on commit 9b14fcb

Please sign in to comment.