diff --git a/pooltool/physics/resolve/resolver.py b/pooltool/physics/resolve/resolver.py index dcfe43e6..e496dda1 100644 --- a/pooltool/physics/resolve/resolver.py +++ b/pooltool/physics/resolve/resolver.py @@ -2,10 +2,13 @@ from __future__ import annotations +import shutil +import traceback from pathlib import Path from typing import Optional import attrs +from cattrs.errors import ClassValidationError import pooltool.user_config from pooltool.events.datatypes import AgentType, Event, EventType @@ -120,40 +123,60 @@ def load(cls, path: Pathish) -> Resolver: @classmethod def default(cls) -> Resolver: """Load ~/.config/pooltool/physics/resolver.yaml if exists, create otherwise""" - if RESOLVER_PATH.exists(): - resolver = cls.load(RESOLVER_PATH) - if resolver.version == VERSION: - return resolver - else: - run.info_single( - f"{RESOLVER_PATH} is has version {resolver.version}, which is not up to " - f"date with the most current version: {VERSION}. It will be replaced with the " - f"default." - ) - - resolver = cls( - ball_ball=FrictionalMathavan( - friction=AlciatoreBallBallFriction( - a=0.009951, - b=0.108, - c=1.088, + def _default_config(): + return cls( + ball_ball=FrictionalMathavan( + friction=AlciatoreBallBallFriction( + a=0.009951, + b=0.108, + c=1.088, + ), + num_iterations=1000, + ), + ball_linear_cushion=Han2005Linear(), + ball_circular_cushion=Han2005Circular(), + ball_pocket=CanonicalBallPocket(), + stick_ball=InstantaneousPoint( + english_throttle=1.0, + squirt_throttle=1.0, ), - num_iterations=1000, - ), - ball_linear_cushion=Han2005Linear(), - ball_circular_cushion=Han2005Circular(), - ball_pocket=CanonicalBallPocket(), - stick_ball=InstantaneousPoint( - english_throttle=1.0, - squirt_throttle=1.0, - ), - transition=CanonicalTransition(), - version=VERSION, - ) - - resolver.save(RESOLVER_PATH) - return resolver + transition=CanonicalTransition(), + version=VERSION, + ) + + if not RESOLVER_PATH.exists(): + resolver = _default_config() + resolver.save(RESOLVER_PATH) + return resolver + + try: + resolver = cls.load(RESOLVER_PATH) + except ClassValidationError: + full_traceback = traceback.format_exc() + dump_path = RESOLVER_PATH.parent / f".{RESOLVER_PATH.name}" + run.info_single( + f"{RESOLVER_PATH} is malformed and can't be loaded. It is being " + f"replaced with a default working version. Your version has been moved to " + f"{dump_path} if you want to diagnose it. Here is the error:\n{full_traceback}" + ) + shutil.move(RESOLVER_PATH, dump_path) + resolver = _default_config() + resolver.save(RESOLVER_PATH) + + if resolver.version == VERSION: + return resolver + else: + dump_path = RESOLVER_PATH.parent / f".{RESOLVER_PATH.name}" + run.info_single( + f"{RESOLVER_PATH} is has version {resolver.version}, which is not up to " + f"date with the most current version: {VERSION}. It will be replaced with the " + f"default. Your version has been moved to {dump_path}." + ) + shutil.move(RESOLVER_PATH, dump_path) + resolver = _default_config() + resolver.save(RESOLVER_PATH) + return resolver def _snapshot_initial(shot: System, event: Event) -> None: