first go at adding observation transform
Showing 2 changed files with 206 additions and 0 deletions.
@@ -0,0 +1,55 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from anemoi.transform.filters import filter_registry
from anemoi.transform.sources import source_registry
from anemoi.transform.workflows import workflow_registry
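
# Each registry maps a name to an implementation: create("<name>", ...) builds the
# registered source/filter/workflow with the given keyword arguments (the
# "reshape_odb_df" filter used below is registered in the second file of this commit).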

################

# Retrieve an ODB observation-feedback (type=ofb) table of conventional reports from MARS.
mars = source_registry.create("mars")

r = dict(
    _class="ea",
    expver="0001",
    stream="oper",
    obsgroup="conv",
    reportype="16001/16002",
    date="20241212",
    type="ofb",
    time="00/06/12/18",
    filter="'select reportype,seqno,date,time,lat,lon,report_status,report_event1,entryno,varno,statid,stalt,obsvalue,lsm@modsurf,biascorr_fg,final_obs_error,datum_status@body,datum_event1@body,vertco_reference_1,vertco_type where ((varno==39 and abs(fg_depar@body)<20) or (varno in (41,42) and abs(fg_depar@body)<15) or (varno==58 and abs(fg_depar@body)<0.4) or (varno == 110 and entryno == 1 and abs(fg_depar@body)<10000)) and time in (000000,030000,060000,090000,120000,150000,180000,210000)'",
)
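
# The ODB "filter" entry above is an SQL-like select: it keeps the listed columns
# and screens observations by first-guess departure (fg_depar@body) thresholds per
# variable number (varno 39, 41/42, 58 and 110) and by report time.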

data = mars.forward(r)

print(data)
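
# At this point `data` is expected to be the raw ODB result (an earthkit-data ODB
# reader), which the reshape filter below converts into a wide pandas DataFrame.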

################

# Pivot the row-wise ODB records into one row per report, one column per variable.
odb2df = filter_registry.create(
    "reshape_odb_df",
    predicted_cols=["obsvalue@body"],
    pivot_cols=["varno@body"],
    meta_cols=["reportype", "stalt@hdr", "lsm@modsurf"],
    drop_nans=True,
)

data = odb2df.forward(data)
print(data)

################

# Compose the same steps as a named workflow ...
pipeline = workflow_registry.create("pipeline", filters=[mars, odb2df])
print(pipeline)

################

# ... or by piping the request through the source and filter directly.
pipeline = r | mars | odb2df
print(pipeline)

# ipipe = pipeline.to_inference()
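
The reshape step above boils down to a pandas pivot from long format (one row per report and variable) to wide format (one row per report). Below is a minimal, self-contained sketch of that pivot, using made-up values and column names taken from the request above; it does not call the new filter itself, only plain pandas.

import pandas as pd

# Toy long-format ODB-like table: one row per (report, variable).
long_df = pd.DataFrame(
    {
        "seqno@hdr": [1, 1, 2, 2],
        "lat@hdr": [51.2, 51.2, 48.9, 48.9],
        "lon@hdr": [-0.5, -0.5, 2.3, 2.3],
        "varno@body": [39, 41, 39, 41],
        "obsvalue@body": [271.3, 3.4, 275.1, 1.2],
    }
)

# One row per report, one obsvalue column per varno (MultiIndex columns).
wide_df = long_df.pivot(
    index=["seqno@hdr", "lat@hdr", "lon@hdr"],
    columns=["varno@body"],
    values=["obsvalue@body"],
).reset_index()

print(wide_df)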

@@ -0,0 +1,151 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import json
import logging
import os.path
from typing import List

import pandas as pd
from earthkit.data.readers.odb import ODBReader

from . import filter_registry
from .base import SimpleFilter

# Per-report identifier, geolocation and time columns in the ODB schema.
INDEX_COL = "seqno@hdr"
GEOLOCATION_META_COLS = ["lat@hdr", "lon@hdr", "date@hdr", "time@hdr"]
VARNO_COL = "varno@body"


class ReshapeODBDF(SimpleFilter):
    """A filter that reshapes an ODB dataframe from long format (one row per
    report and variable) to wide format (one row per report)."""

    def __init__(
        self,
        *,
        sort_by: List[str] = ["date@hdr", "time@hdr"],
        meta_cols: List[str] = [],
        meta_body_cols: List[str] = [],
        extra_obsval_cols: List[str] = [],
        predicted_cols: List[str] = ["obsvalue@body"],
        pivot_cols: List[str] = ["varno@body"],
        drop_nans: bool = False,
    ):
        self.sort_by = sort_by
        self.meta_cols = meta_cols
        self.meta_body_cols = meta_body_cols
        self.extra_obsval_cols = extra_obsval_cols
        self.predicted_cols = predicted_cols
        self.pivot_cols = pivot_cols
        self.drop_nans = drop_nans

        if not all([self.predicted_cols, self.pivot_cols]):
            raise ValueError("'predicted_cols' and 'pivot_cols' must be specified")

    def forward(self, data):
        yield self._transform(data, self.forward_transform)

    def backward(self, data):
        raise NotImplementedError("ReshapeODBDF is not reversible")

    def forward_transform(self, data: ODBReader) -> pd.DataFrame:
        """
        Restructure a dataframe from the native ODB schema:
          i) pivot per-channel / per-variable values from row-wise to column-wise
          ii) rename the columns
          iii) sort the data
        """
        index_cols = [INDEX_COL] + GEOLOCATION_META_COLS + self.meta_cols
        value_cols = self.predicted_cols + self.meta_body_cols

        # When pivoting on the vertical coordinate, keep varno as an extra pivot
        # level so that each (variable, level) pair becomes its own column.
        pivot_colname = (
            ["varno@body", "vertco_reference_1@body"]
            if self.pivot_cols == ["vertco_reference_1@body"]
            else self.pivot_cols
        )

        df = data.to_pandas()
        df = df.drop_duplicates(subset=index_cols + pivot_colname, keep="first")

        df_pivot = df.pivot(index=index_cols, columns=pivot_colname, values=value_cols)
        df_pivot = df_pivot.sort_values(by=self.sort_by, kind="stable").reset_index()

        df_meta = df_pivot[index_cols]
        df_obs = df_pivot.drop(columns=index_cols, level=0).sort_index(axis=1)
        df_out = pd.concat([df_meta, df_obs], axis=1)

        if self.drop_nans:
            df_out = df_out.dropna()

        # Combine date@hdr (YYYYMMDD) and time@hdr (HHMMSS, zero-padded) into a
        # single datetime column.
        df_out["datetime"] = pd.to_datetime(
            df_out["date@hdr"].astype(int).astype(str)
            + df_out["time@hdr"].astype(int).astype(str).str.zfill(6),
            format="%Y%m%d%H%M%S",
        )
        df_out = df_out.drop(columns=["date@hdr", "time@hdr"], level=0)

        df_out.columns = self.rename_columns(df_out.columns.tolist(), self.extra_obsval_cols)

        return df_out

    def rename_columns(self, tup_list: List, extra_obsval_cols: List[str]) -> List[str]:
        """
        Rename the columns using the convention: obsvalue_{varno_name}_{vertco_reference_1}.
        Note: non-obsvalue columns simply have the "@table" suffix stripped from the name.
        Args:
            tup_list: List of tuples from pandas multi-index column names,
                e.g. ("obsvalue@body", 39) -> "obsvalue_t2m_0"
                     ("obsvalue@body", 119, 22) -> "obsvalue_rawbt_22"
            extra_obsval_cols: List of additional column names to be treated as observation values
        Returns:
            List of new column names
        """
        # Load the varno code table shipped with the package (maps numeric varno
        # codes to variable names).
        path = os.path.dirname(os.path.abspath(__file__))
        with open(f"{path}/../data/varno.json") as f:
            varno_dict = json.load(f)

        out_colnames = []
        for tup in tup_list:
            colname = tup[0]
            varno = tup[1] if len(tup) > 1 else ""
            vertco_reference_1 = tup[2] if len(tup) > 2 else ""

            base_colname = colname.split("@")[0]

            if base_colname in extra_obsval_cols:
                base_colname = f"obsvalue_{base_colname}"

            # Compute the vertical-coordinate suffix up front so it is also
            # available in the fallback branch below.
            vertco_suffix = f"{int(vertco_reference_1)}" if vertco_reference_1 else "0"

            if not varno:
                out_colnames.append(base_colname)
            else:
                try:
                    varno_idx = next(
                        i
                        for i, varno_lst in enumerate(varno_dict["data"])
                        if int(varno) in varno_lst
                    )
                    varno_name = varno_dict["data"][varno_idx][0]
                    out_colnames.append(f"{base_colname}_{varno_name}_{vertco_suffix}")
                except (ValueError, StopIteration):
                    logging.warning(
                        f"Unable to find varno name for {varno}. Using original varno."
                    )
                    out_colnames.append(f"{base_colname}_{varno}_{vertco_suffix}")

        return out_colnames
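
    # For reference, a minimal sketch of the varno.json structure this method
    # assumes (the file itself is not part of this commit): a "data" list whose
    # entries start with the variable name followed by the numeric varno codes
    # that map to it, e.g.
    #
    #   {"data": [["t2m", 39], ["rawbt", 119]]}
    #
    # so that varno 39 renames "obsvalue@body" to "obsvalue_t2m_0".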

    def backward_transform(self, data: ODBReader) -> None:
        raise NotImplementedError("ReshapeODBDF is not reversible")


filter_registry.register("reshape_odb_df", ReshapeODBDF)