Skip to content

Commit

Permalink
Refactor Resolver serialization + ball ball friction sub models (#171)
Browse files Browse the repository at this point in the history
* add TP A-14 speed-dependent ball-ball friction; still needs configurable friction parameters

* Remove ResolverConfig, directly serialize/deserialize Resolver instead

* Update doc up to creating your own model

* Cattrs structuring failure leads to default fallback with warning

---------

Co-authored-by: Derek McBlane <mcblanederek@gmail.com>
  • Loading branch information
ekiefl and derek-mcblane authored Jan 8, 2025
1 parent 9069d7f commit f5bd153
Show file tree
Hide file tree
Showing 24 changed files with 1,840 additions and 1,688 deletions.
14 changes: 6 additions & 8 deletions docs/examples/30_degree_rule.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.16.4
# jupytext_version: 1.16.6
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
Expand Down Expand Up @@ -122,10 +122,10 @@
# %% [markdown]
# Now, we [simulate](../autoapi/pooltool/index.rst#pooltool.simulate) the shot and then [continuize](../autoapi/pooltool/evolution/continuize/index.html#pooltool.evolution.continuize.continuize) it to store ball state data (like coordinates) in $10\text{ms}$ timestep intervals.

# %%
# Create a default physics engine and overwrite ball-ball model with frictionless, elastic model.
# %% trusted=true
# Create a default physics engine, then overwrite ball-ball model with frictionless, elastic model.
engine = pt.physics.PhysicsEngine()
engine.resolver.ball_ball = pt.physics.get_ball_ball_model(pt.physics.BallBallModel.FRICTIONLESS_ELASTIC)
engine.resolver.ball_ball = pt.physics.ball_ball_models[pt.physics.BallBallModel.FRICTIONLESS_ELASTIC]()

pt.simulate(system, engine=engine, inplace=True)
pt.continuize(system, dt=0.01, inplace=True)
Expand Down Expand Up @@ -268,10 +268,8 @@ def get_carom_angle(system: pt.System) -> float:
pt.events.by_type(pt.EventType.SLIDING_ROLLING),
)[0]

velocity_final = transition.agents[0].final.state.rvw[1, :2]
for agent in collision.agents:
if agent.id == "cue":
velocity_initial = agent.initial.state.rvw[1, :2]
velocity_final = transition.get_ball("cue", initial=False).state.rvw[1, :2]
velocity_initial = transition.get_ball("cue", initial=True).state.rvw[1, :2]

return pt.ptmath.utils.angle_between_vectors(velocity_final, velocity_initial)

Expand Down
267 changes: 127 additions & 140 deletions docs/resources/custom_physics.md

Large diffs are not rendered by default.

2,286 changes: 1,168 additions & 1,118 deletions poetry.lock

Large diffs are not rendered by default.

33 changes: 14 additions & 19 deletions pooltool/physics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,35 @@
from pooltool.physics.resolve.ball_ball import (
BallBallCollisionStrategy,
BallBallModel,
get_ball_ball_model,
ball_ball_models,
)
from pooltool.physics.resolve.ball_cushion import (
BallCCushionCollisionStrategy,
BallCCushionModel,
BallLCushionCollisionStrategy,
BallLCushionModel,
get_ball_circ_cushion_model,
get_ball_lin_cushion_model,
ball_ccushion_models,
ball_lcushion_models,
)
from pooltool.physics.resolve.ball_pocket import (
BallPocketModel,
BallPocketStrategy,
get_ball_pocket_model,
ball_pocket_models,
)
from pooltool.physics.resolve.resolver import (
RESOLVER_CONFIG_PATH,
RESOLVER_PATH,
Resolver,
ResolverConfig,
)
from pooltool.physics.resolve.stick_ball import (
StickBallCollisionStrategy,
StickBallModel,
get_stick_ball_model,
stick_ball_models,
)
from pooltool.physics.resolve.transition import (
BallTransitionModel,
BallTransitionStrategy,
get_transition_model,
ball_transition_models,
)
from pooltool.physics.resolve.types import ArgType, ModelArgs

__all__ = [
"BallBallCollisionStrategy",
Expand All @@ -48,20 +46,17 @@
"evolve",
"resolve",
"Resolver",
"RESOLVER_CONFIG_PATH",
"ResolverConfig",
"RESOLVER_PATH",
"BallBallModel",
"get_ball_ball_model",
"BallCCushionModel",
"BallLCushionModel",
"get_ball_circ_cushion_model",
"get_ball_lin_cushion_model",
"BallPocketModel",
"get_ball_pocket_model",
"StickBallModel",
"get_stick_ball_model",
"BallTransitionModel",
"get_transition_model",
"ArgType",
"ModelArgs",
"ball_ball_models",
"ball_lcushion_models",
"ball_ccushion_models",
"ball_pocket_models",
"stick_ball_models",
"ball_transition_models",
]
74 changes: 58 additions & 16 deletions pooltool/physics/resolve/__init__.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,101 @@
"""Resolve events"""

import inspect

import attrs

from pooltool.physics.resolve.ball_ball import (
BallBallCollisionStrategy,
BallBallModel,
get_ball_ball_model,
ball_ball_models,
)
from pooltool.physics.resolve.ball_cushion import (
BallCCushionCollisionStrategy,
BallCCushionModel,
BallLCushionCollisionStrategy,
BallLCushionModel,
get_ball_circ_cushion_model,
get_ball_lin_cushion_model,
ball_ccushion_models,
ball_lcushion_models,
)
from pooltool.physics.resolve.ball_pocket import (
BallPocketModel,
BallPocketStrategy,
get_ball_pocket_model,
ball_pocket_models,
)
from pooltool.physics.resolve.resolver import (
RESOLVER_CONFIG_PATH,
RESOLVER_PATH,
Resolver,
ResolverConfig,
)
from pooltool.physics.resolve.stick_ball import (
StickBallCollisionStrategy,
StickBallModel,
get_stick_ball_model,
stick_ball_models,
)
from pooltool.physics.resolve.transition import (
BallTransitionModel,
BallTransitionStrategy,
get_transition_model,
ball_transition_models,
)


def _display_model(cls, model):
fp = inspect.getfile(cls)
print(f" {model.value} ({fp})")

if not attrs.has(cls):
raise TypeError(f"{cls.__name__} is not an attrs class.")

indent = 4
indent_str = " " * indent

for field in attrs.fields(cls):
if field.name == "model":
continue

default_val = field.default
if default_val is attrs.NOTHING:
default_val = None

print(
f"{indent_str} - {field.name}: "
f"type={field.type}, default={default_val}"
)


def display_models():
print("\nball_ball models:")
for model in BallBallModel:
_display_model(ball_ball_models[model], model)
print("\nball_linear_cushion models:")
for model in BallLCushionModel:
_display_model(ball_lcushion_models[model], model)
print("\nball_circular_cushion models:")
for model in BallCCushionModel:
_display_model(ball_ccushion_models[model], model)
print("\nstick_ball models:")
for model in StickBallModel:
_display_model(stick_ball_models[model], model)
print("\nball_pocket models:")
for model in BallPocketModel:
_display_model(ball_pocket_models[model], model)
print("\nball_transition models:")
for model in BallTransitionModel:
_display_model(ball_transition_models[model], model)


__all__ = [
"Resolver",
"RESOLVER_CONFIG_PATH",
"ResolverConfig",
"RESOLVER_PATH",
"BallBallCollisionStrategy",
"BallBallModel",
"get_ball_ball_model",
"BallCCushionCollisionStrategy",
"BallCCushionModel",
"BallLCushionCollisionStrategy",
"BallLCushionModel",
"get_ball_circ_cushion_model",
"get_ball_lin_cushion_model",
"BallPocketModel",
"BallPocketStrategy",
"get_ball_pocket_model",
"StickBallCollisionStrategy",
"StickBallModel",
"get_stick_ball_model",
"BallTransitionModel",
"BallTransitionStrategy",
"get_transition_model",
]
58 changes: 13 additions & 45 deletions pooltool/physics/resolve/ball_ball/__init__.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,31 @@
"""Models for ball-ball collisions."""

from typing import Dict, Optional, Type
from typing import Dict, Tuple, Type, cast

import attrs

from pooltool.physics.resolve.ball_ball.core import BallBallCollisionStrategy
from pooltool.physics.resolve.ball_ball.frictional_inelastic import FrictionalInelastic
from pooltool.physics.resolve.ball_ball.frictional_mathavan import FrictionalMathavan
from pooltool.physics.resolve.ball_ball.frictionless_elastic import FrictionlessElastic
from pooltool.physics.resolve.types import ModelArgs
from pooltool.utils.strenum import StrEnum, auto


class BallBallModel(StrEnum):
"""An Enum for different ball-ball collision models
from pooltool.physics.resolve.models import BallBallModel

Attributes:
FRICTIONLESS_ELASTIC:
Frictionless, instantaneous, elastic, equal mass collision
(:class:`FrictionlessElastic`).
"""
_ball_ball_model_registry: Tuple[Type[BallBallCollisionStrategy], ...] = (
FrictionlessElastic,
FrictionalMathavan,
FrictionalInelastic,
)

FRICTIONLESS_ELASTIC = auto()
FRICTIONAL_INELASTIC = auto()
FRICTIONAL_MATHAVAN = auto()


_ball_ball_models: Dict[BallBallModel, Type[BallBallCollisionStrategy]] = {
BallBallModel.FRICTIONLESS_ELASTIC: FrictionlessElastic,
BallBallModel.FRICTIONAL_INELASTIC: FrictionalInelastic,
BallBallModel.FRICTIONAL_MATHAVAN: FrictionalMathavan,
ball_ball_models: Dict[BallBallModel, Type[BallBallCollisionStrategy]] = {
cast(BallBallModel, attrs.fields_dict(cls)["model"].default): cls
for cls in _ball_ball_model_registry
}


def get_ball_ball_model(
model: Optional[BallBallModel] = None, params: ModelArgs = {}
) -> BallBallCollisionStrategy:
"""Returns a ball-ball collision model
Args:
model:
An Enum specifying the desired model. If not passed,
:class:`FrictionalMathavan` is passed with empty params.
params:
A mapping of parameters accepted by the model.
Returns:
An instantiated model that satisfies the :class:`BallBallCollisionStrategy`
protocol.
"""

if model is None:
return FrictionlessElastic()

return _ball_ball_models[model](**params)


__all__ = [
"BallBallModel",
"get_ball_ball_model",
"FrictionalMathavan",
"FrictionalInelastic",
"FrictionlessElastic",
"ball_ball_models",
]
66 changes: 66 additions & 0 deletions pooltool/physics/resolve/ball_ball/friction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import math
from typing import Dict, Protocol, Type

import attrs
import numpy as np

import pooltool.ptmath as ptmath
from pooltool.objects.ball.datatypes import Ball
from pooltool.utils.strenum import StrEnum, auto


class BallBallFrictionModel(StrEnum):
"""An Enum for different ball-ball friction models"""

AVERAGE = auto()
ALCIATORE = auto()


class BallBallFrictionStrategy(Protocol):
"""Ball-ball friction models must satisfy this protocol"""

def calculate_friction(self, ball1: Ball, ball2: Ball) -> float:
"""This method calculates ball-ball friction"""
...


@attrs.define
class AlciatoreBallBallFriction:
"""Friction fit curve u_b = a + b * exp(-c * v_rel) used in David Alciatore's TP A-14"""

a: float = 9.951e-3
b: float = 0.108
c: float = 1.088

model: BallBallFrictionModel = attrs.field(
default=BallBallFrictionModel.ALCIATORE, init=False, repr=False
)

def calculate_friction(self, ball1: Ball, ball2: Ball) -> float:
unit_x = np.array([1.0, 0.0, 0.0])
v1_c = ptmath.surface_velocity(
ball1.state.rvw, unit_x, ball1.params.R
) - np.array([ball1.state.rvw[1][0], 0, 0])
v2_c = ptmath.surface_velocity(
ball2.state.rvw, -unit_x, ball2.params.R
) - np.array([ball2.state.rvw[1][0], 0, 0])
relative_surface_speed = ptmath.norm3d(v1_c - v2_c)
return self.a + self.b * math.exp(-self.c * relative_surface_speed)


@attrs.define
class AverageBallBallFriction:
model: BallBallFrictionModel = attrs.field(
default=BallBallFrictionModel.AVERAGE, init=False, repr=False
)

def calculate_friction(self, ball1: Ball, ball2: Ball) -> float:
return (ball1.params.u_b + ball2.params.u_b) / 2


ball_ball_friction_models: Dict[
BallBallFrictionModel, Type[BallBallFrictionStrategy]
] = {
BallBallFrictionModel.AVERAGE: AverageBallBallFriction,
BallBallFrictionModel.ALCIATORE: AlciatoreBallBallFriction,
}
Loading

0 comments on commit f5bd153

Please sign in to comment.