From ff5d6f90976ff6475b5d718d737a0c9a277397a9 Mon Sep 17 00:00:00 2001 From: Siddharth Narayanan Date: Sun, 29 Dec 2024 19:07:09 -0800 Subject: [PATCH] fixing issue with deserializating logprob --- ldp/data_structures.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ldp/data_structures.py b/ldp/data_structures.py index cfc64206..f696d174 100644 --- a/ldp/data_structures.py +++ b/ldp/data_structures.py @@ -4,6 +4,7 @@ import logging import os from collections.abc import Callable, Hashable, Iterable +from contextlib import suppress from typing import Any, ClassVar, Self, cast from uuid import UUID @@ -121,7 +122,12 @@ def from_jsonl(cls, filename: str | os.PathLike) -> Self: reader = iter(f) traj = cls(traj_id=json.loads(next(reader))) for json_line in reader: - traj.steps.append(Transition(**json.loads(json_line))) + data = json.loads(json_line) + # logprob may have been serialized, but cnanot be passed to + # OpResult, so remove it here. + with suppress(KeyError): + data["action"].pop("logprob") + traj.steps.append(Transition(**data)) return traj def compute_discounted_returns(self, discount: float = 1.0) -> list[float]: