diff --git a/docs/examples/30_degree_rule.pct.py b/docs/examples/30_degree_rule.pct.py
index 69f9bad8..07543606 100644
--- a/docs/examples/30_degree_rule.pct.py
+++ b/docs/examples/30_degree_rule.pct.py
@@ -75,7 +75,7 @@
#
# We'll start with a table. Since we don't want collisions with cushions to interfere with our trajectory, let's make an unrealistically large $10\text{m} \times 10\text{m}$ [Table](../autoapi/pooltool/index.rst#pooltool.Table).
-# %% trusted=true
+# %%
import pooltool as pt
table_specs = pt.objects.BilliardTableSpecs(l=10, w=10)
@@ -85,20 +85,20 @@
# %% [markdown]
# Next, we'll create two [Ball](../autoapi/pooltool/index.rst#pooltool.Ball) objects.
-# %% trusted=true
+# %%
cue_ball = pt.Ball.create("cue", xy=(2.5, 1.5))
obj_ball = pt.Ball.create("obj", xy=(2.5, 3.0))
# %% [markdown]
# Next, we'll need a [Cue](../autoapi/pooltool/index.rst#pooltool.Cue).
-# %% trusted=true
+# %%
cue = pt.Cue(cue_ball_id="cue")
# %% [markdown]
# Finally, we'll need to wrap these objects up into a [System](../autoapi/pooltool/index.rst#pooltool.System). We'll call this our system *template*, with the intention of reusing it for many different shots.
-# %% trusted=true
+# %%
system_template = pt.System(
table=table,
cue=cue,
@@ -112,7 +112,7 @@
#
# So in the function call below, `pt.aim.at_ball(system, "obj", cut=30)` returns the angle `phi` that the cue ball should be directed at such that a cut angle of 30 degrees with the object ball is achieved.
-# %% trusted=true
+# %%
# Creates a deep copy of the template
system = system_template.copy()
@@ -144,7 +144,7 @@
#
# Since that can't be embedded into the documentation, we'll instead plot the trajectory of the cue ball and object ball by accessing ther historical states.
-# %% trusted=true
+# %%
cue_ball = system.balls["cue"]
obj_ball = system.balls["obj"]
cue_history = cue_ball.history_cts
@@ -154,7 +154,7 @@
# %% [markdown]
# The [BallHistory](../autoapi/pooltool/objects/index.rst#pooltool.objects.BallHistory) holds the ball's historical states, each stored as a [BallState](../autoapi/pooltool/objects/index.rst#pooltool.objects.BallState) object. Each attribute of the ball states can be concatenated into numpy arrays with the [BallHistory.vectorize](../autoapi/pooltool/objects/index.rst#pooltool.objects.BallHistory.vectorize) method.
-# %% trusted=true
+# %%
rvw_cue, s_cue, t_cue = cue_history.vectorize()
rvw_obj, s_obj, t_obj = obj_history.vectorize()
@@ -165,12 +165,12 @@
# %% [markdown]
# We can grab the xy-coordinates from the `rvw` array by with the following.
-# %% trusted=true
+# %%
coords_cue = rvw_cue[:, 0, :2]
coords_obj = rvw_obj[:, 0, :2]
coords_cue.shape
-# %% trusted=true editable=true slideshow={"slide_type": ""} tags=[]
+# %% editable=true slideshow={"slide_type": ""} tags=[]
import plotly.graph_objects as go
import plotly.io as pio
@@ -203,7 +203,7 @@
#
# As mentioned before, the carom angle is the angle between the cue ball velocity right before collision, and the cue ball velocity post-collision, once the ball has stopped sliding on the cloth. Hidden somewhere in the system **event list** one can find the events corresponding to these precise moments in time:
-# %% trusted=true
+# %%
system.events[:6]
# %% [markdown]
@@ -211,20 +211,20 @@
#
# Since there is only one ball-ball collision, it's easy to select with [filter_type](../autoapi/pooltool/events/index.rst#pooltool.events.filter_type):
-# %% trusted=true
+# %%
collision = pt.events.filter_type(system.events, pt.EventType.BALL_BALL)[0]
collision
# %% [markdown]
# To get the event when the cue ball stops sliding, we can similarly try filtering by the sliding to rolling transition event:
-# %% trusted=true
+# %%
pt.events.filter_type(system.events, pt.EventType.SLIDING_ROLLING)
# %% [markdown]
# But there are many sliding to rolling transition events, and to make matters worse, they are shared by both the cue ball and the object ball. What we need is the **first** **sliding to rolling** transition that the **cue ball** undergoes **after** the **ball-ball** collision. We can achieve this multi-criteria query with [filter_events](../autoapi/pooltool/events/index.rst#pooltool.events.filter_events):
-# %% trusted=true
+# %%
transition = pt.events.filter_events(
system.events,
pt.events.by_time(t=collision.time, after=True),
@@ -236,16 +236,13 @@
# %% [markdown]
# Now, we can dive into these two events and pull out the cue ball velocities we need to calculate the carom angle.
-# %% trusted=true
+# %%
# Velocity prior to impact
-for agent in collision.agents:
- if agent.id == "cue":
- # agent.initial is a copy of the Ball before resolving the collision
- velocity_initial = agent.initial.state.rvw[1, :2]
+velocity_initial = collision.get_ball("cue", initial=True).vel[:2]
# Velocity post sliding
-# We choose `final` here for posterity, but the velocity is the same both before and after resolving the transition.
-velocity_final = transition.agents[0].final.state.rvw[1, :2]
+# We choose the "final" here for posterity, but the velocity is the same both before and after resolving the transition.
+velocity_final = transition.get_ball("cue", initial=False).vel[:2]
carom_angle = pt.ptmath.utils.angle_between_vectors(velocity_final, velocity_initial)
@@ -259,7 +256,7 @@
# We calculated the carom angle for a single cut angle, 30 degrees. Let's write a function called `get_carom_angle` so we can do that repeatedly for different cut angles.
-# %% trusted=true
+# %%
def get_carom_angle(system: pt.System) -> float:
assert system.simulated
@@ -283,7 +280,7 @@ def get_carom_angle(system: pt.System) -> float:
# `get_carom_angle` assumes the passed system has already been simulated, so we'll need another function to take care of that. We'll cue stick speed and cut angle as parameters.
-# %% trusted=true
+# %%
def simulate_experiment(V0: float, cut_angle: float) -> pt.System:
system = system_template.copy()
phi = pt.aim.at_ball(system, "obj", cut=cut_angle)
@@ -296,7 +293,7 @@ def simulate_experiment(V0: float, cut_angle: float) -> pt.System:
# We'll also want the ball hit fraction:
-# %% trusted=true
+# %%
import numpy as np
def get_ball_hit_fraction(cut_angle: float) -> float:
@@ -306,7 +303,7 @@ def get_ball_hit_fraction(cut_angle: float) -> float:
# %% [markdown]
# With these functions, we are ready to simulate how carom angle varies as a function of cut angle.
-# %% trusted=true
+# %%
import pandas as pd
data = {
@@ -329,7 +326,7 @@ def get_ball_hit_fraction(cut_angle: float) -> float:
# %% [markdown]
# From this dataframe we can make some plots. On top of the ball-hit fraction, plot, I'll create a box between a $1/4$ ball hit and a $3/4$ ball hit, since this is the carom angle range that the 30-degree rule is defined with respect to.
-# %% trusted=true editable=true slideshow={"slide_type": ""} tags=["nbsphinx-thumbnail"]
+# %% editable=true slideshow={"slide_type": ""} tags=["nbsphinx-thumbnail"]
import matplotlib.pyplot as plt
x_min = 0.25
@@ -354,7 +351,7 @@ def get_ball_hit_fraction(cut_angle: float) -> float:
#
# For your reference, here is the same plot but with cut angle $\phi$ as the x-axis:
-# %% trusted=true
+# %%
fig, ax = plt.subplots()
ax.scatter(frame['phi'], frame['theta'], color='#1f77b4')
ax.set_title('Carom Angle vs Cut Angle', fontsize=20)
@@ -378,7 +375,7 @@ def get_ball_hit_fraction(cut_angle: float) -> float:
# Since pooltool's baseline physics engine makes the same assumptions, we should expect the results to be the same. Let's directly compare:
-# %% trusted=true
+# %%
def get_theoretical_carom_angle(phi) -> float:
return np.atan2(np.sin(phi) * np.cos(phi), (np.sin(phi) ** 2 + 2 / 5))
@@ -425,7 +422,7 @@ def get_theoretical_carom_angle(phi) -> float:
#
# Interestingly, the carom angle is independent of the speed:
-# %% trusted=true
+# %%
for V0 in np.linspace(1, 4, 20):
system = simulate_experiment(V0, 30)
carom_angle = get_carom_angle(system)
@@ -434,7 +431,7 @@ def get_theoretical_carom_angle(phi) -> float:
# %% [markdown]
# This doesn't mean that the trajectories are the same though. Here are the trajectories:
-# %% trusted=true
+# %%
import numpy as np
import plotly.graph_objects as go
diff --git a/docs/examples/straight_shot.pct.py b/docs/examples/straight_shot.pct.py
index dc1e3e57..85f7680d 100644
--- a/docs/examples/straight_shot.pct.py
+++ b/docs/examples/straight_shot.pct.py
@@ -115,7 +115,7 @@ def create_system(d, D):
obj_ball = create_object_ball(cue_ball, d)
return pt.System(
- cue=pt.Cue.default(),
+ cue=pt.Cue(cue_ball_id="CB"),
balls=(cue_ball, obj_ball),
table=table,
)
diff --git a/pooltool/ani/animate.py b/pooltool/ani/animate.py
index ca3cb2e4..8d818bac 100755
--- a/pooltool/ani/animate.py
+++ b/pooltool/ani/animate.py
@@ -506,7 +506,8 @@ def create_system(self):
game = get_ruleset(game_type)()
game.players = [
- Player("Player"),
+ Player("Player 1"),
+ Player("Player 2"),
]
table = Table.from_game_type(game_type)
diff --git a/pooltool/events/datatypes.py b/pooltool/events/datatypes.py
index 741d6b9e..58b7b410 100644
--- a/pooltool/events/datatypes.py
+++ b/pooltool/events/datatypes.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from functools import partial
-from typing import Any, Dict, Optional, Tuple, Type, Union
+from typing import Any, Dict, Optional, Tuple, Type, Union, cast
from attrs import define, evolve, field
from cattrs.converters import Converter
@@ -75,6 +75,35 @@ def is_transition(self) -> bool:
EventType.SLIDING_ROLLING,
}
+ def has_ball(self) -> bool:
+ """Returns True if this event type can involve a Ball."""
+ return (
+ self
+ in {
+ EventType.BALL_BALL,
+ EventType.BALL_LINEAR_CUSHION,
+ EventType.BALL_CIRCULAR_CUSHION,
+ EventType.BALL_POCKET,
+ EventType.STICK_BALL,
+ }
+ or self.is_transition()
+ )
+
+ def has_cushion(self) -> bool:
+ """Returns True if this event type can involve a cushion (linear or circular)."""
+ return self in {
+ EventType.BALL_LINEAR_CUSHION,
+ EventType.BALL_CIRCULAR_CUSHION,
+ }
+
+ def has_pocket(self) -> bool:
+ """Returns True if this event type can involve a Pocket."""
+ return self == EventType.BALL_POCKET
+
+ def has_stick(self) -> bool:
+ """Returns True if this event type can involve a CueStick."""
+ return self == EventType.STICK_BALL
+
Object = Union[
NullObject,
@@ -185,22 +214,6 @@ def set_final(self, obj: Object) -> None:
else:
self.final = obj.copy()
- def matches(self, obj: Object) -> bool:
- """Determines if the given object matches the agent.
-
- It checks if the object is of the correct class type and if the IDs match.
-
- Args:
- obj: The object to compare with the agent.
-
- Returns:
- bool:
- True if the object's class type and ID match the agent's type and ID,
- False otherwise.
- """
- correct_class = _class_to_type[type(obj)] == self.agent_type
- return correct_class and obj.id == self.id
-
@staticmethod
def from_object(obj: Object, set_initial: bool = False) -> Agent:
"""Creates an agent instance from an object.
@@ -228,6 +241,17 @@ def copy(self) -> Agent:
"""Create a copy."""
return evolve(self)
+ def _get_state(self, initial: bool) -> Object:
+ """Return either the initial or final state of the given agent.
+
+ Raises ValueError if that state is None.
+ """
+ obj = self.initial if initial else self.final
+ if obj is None:
+ which = "initial" if initial else "final"
+ raise ValueError(f"Agent '{self.id}' has no {which} state in this event.")
+ return obj
+
def _disambiguate_agent_structuring(
uo: Dict[str, Any], _: Type[Agent], con: Converter
@@ -329,3 +353,81 @@ def copy(self) -> Event:
"""Create a copy."""
# NOTE is this deep-ish copy?
return evolve(self)
+
+ def _find_agent(self, agent_type: AgentType, agent_id: str) -> Agent:
+ """Return the Agent with the specified agent_type and ID.
+
+ Raises:
+ ValueError if not found.
+ """
+ for agent in self.agents:
+ if agent.agent_type == agent_type and agent.id == agent_id:
+ return agent
+ raise ValueError(
+ f"No agent of type {agent_type} with ID '{agent_id}' found in this event."
+ )
+
+ def get_ball(self, ball_id: str, initial: bool = True) -> Ball:
+ """Return the Ball object with the given ID, either final or initial.
+
+ Args:
+ ball_id: The ID of the ball to retrieve.
+ initial: If True, return the ball's initial state; otherwise final state.
+
+ Raises:
+ ValueError: If the event does not involve a ball or if no matching ball is found.
+ """
+ if not self.event_type.has_ball():
+ raise ValueError(
+ f"Event of type {self.event_type} does not involve a Ball."
+ )
+
+ agent = self._find_agent(AgentType.BALL, ball_id)
+ obj = agent._get_state(initial)
+ return cast(Ball, obj)
+
+ def get_pocket(self, pocket_id: str, initial: bool = True) -> Pocket:
+ """Return the Pocket object with the given ID, either final or initial."""
+ if not self.event_type.has_pocket():
+ raise ValueError(
+ f"Event of type {self.event_type} does not involve a Pocket."
+ )
+
+ agent = self._find_agent(AgentType.POCKET, pocket_id)
+ obj = agent._get_state(initial)
+ return cast(Pocket, obj)
+
+ def get_cushion(
+ self, cushion_id: str
+ ) -> Union[LinearCushionSegment, CircularCushionSegment]:
+ """Return the cushion segment with the given ID."""
+ if not self.event_type.has_cushion():
+ raise ValueError(
+ f"Event of type {self.event_type} does not involve a cushion."
+ )
+
+ try:
+ agent = self._find_agent(AgentType.LINEAR_CUSHION_SEGMENT, cushion_id)
+ return cast(LinearCushionSegment, agent.initial)
+ except ValueError:
+ pass
+
+ try:
+ agent = self._find_agent(AgentType.CIRCULAR_CUSHION_SEGMENT, cushion_id)
+ return cast(CircularCushionSegment, agent.initial)
+ except ValueError:
+ pass
+
+ raise ValueError(
+ f"No agent of linear/circular cushion with ID '{cushion_id}' found in this event."
+ )
+
+ def get_stick(self, stick_id: str) -> Pocket:
+ """Return the cue stick with the given ID."""
+ if not self.event_type.has_pocket():
+ raise ValueError(
+ f"Event of type {self.event_type} does not involve a Pocket."
+ )
+
+ agent = self._find_agent(AgentType.POCKET, stick_id)
+ return cast(Pocket, agent.initial)
diff --git a/pooltool/evolution/continuize.py b/pooltool/evolution/continuize.py
index 139baa64..4ac1ccab 100644
--- a/pooltool/evolution/continuize.py
+++ b/pooltool/evolution/continuize.py
@@ -150,12 +150,7 @@ def continuize(system: System, dt: float = 0.01, inplace: bool = False) -> Syste
# We need to get the ball's outgoing state from the event. We'll
# evolve the system from this state.
- for agent in events[count].agents:
- if agent.matches(ball):
- state = agent.final.state.copy() # type: ignore
- break
- else:
- raise ValueError("No agents in event match ball")
+ state = events[count].get_ball(ball.id, initial=False).state.copy()
rvw, s = state.rvw, state.s
diff --git a/pooltool/layouts.py b/pooltool/layouts.py
index 82905b74..5a4d902b 100755
--- a/pooltool/layouts.py
+++ b/pooltool/layouts.py
@@ -17,7 +17,7 @@
DEFAULT_STANDARD_BALLSET = get_ballset("pooltool_pocket")
DEFAULT_SNOOKER_BALLSET = get_ballset("generic_snooker")
-DEFAULT_THREECUSH_BALLSET = None
+DEFAULT_THREECUSH_BALLSET = get_ballset("billiard")
DEFAULT_SUMTOTHREE_BALLSET = None
diff --git a/pooltool/models/balls/billiard/red.blend b/pooltool/models/balls/billiard/red.blend
new file mode 100644
index 00000000..460f17d2
Binary files /dev/null and b/pooltool/models/balls/billiard/red.blend differ
diff --git a/pooltool/models/balls/billiard/red.glb b/pooltool/models/balls/billiard/red.glb
new file mode 100644
index 00000000..49aede71
Binary files /dev/null and b/pooltool/models/balls/billiard/red.glb differ
diff --git a/pooltool/models/balls/billiard/red.png b/pooltool/models/balls/billiard/red.png
new file mode 100644
index 00000000..7595f337
Binary files /dev/null and b/pooltool/models/balls/billiard/red.png differ
diff --git a/pooltool/models/balls/billiard/red.svg b/pooltool/models/balls/billiard/red.svg
new file mode 100644
index 00000000..3caa93ca
--- /dev/null
+++ b/pooltool/models/balls/billiard/red.svg
@@ -0,0 +1,382 @@
+
+
diff --git a/pooltool/models/balls/billiard/shadow.blend b/pooltool/models/balls/billiard/shadow.blend
new file mode 100644
index 00000000..886dafe9
Binary files /dev/null and b/pooltool/models/balls/billiard/shadow.blend differ
diff --git a/pooltool/models/balls/billiard/shadow.glb b/pooltool/models/balls/billiard/shadow.glb
new file mode 100644
index 00000000..bfc497c6
Binary files /dev/null and b/pooltool/models/balls/billiard/shadow.glb differ
diff --git a/pooltool/models/balls/billiard/white.blend b/pooltool/models/balls/billiard/white.blend
new file mode 100644
index 00000000..1afc400c
Binary files /dev/null and b/pooltool/models/balls/billiard/white.blend differ
diff --git a/pooltool/models/balls/billiard/white.glb b/pooltool/models/balls/billiard/white.glb
new file mode 100644
index 00000000..8cd8f55a
Binary files /dev/null and b/pooltool/models/balls/billiard/white.glb differ
diff --git a/pooltool/models/balls/billiard/white.png b/pooltool/models/balls/billiard/white.png
new file mode 100644
index 00000000..b2d96bab
Binary files /dev/null and b/pooltool/models/balls/billiard/white.png differ
diff --git a/pooltool/models/balls/billiard/white.svg b/pooltool/models/balls/billiard/white.svg
new file mode 100644
index 00000000..4386e42c
--- /dev/null
+++ b/pooltool/models/balls/billiard/white.svg
@@ -0,0 +1,382 @@
+
+
diff --git a/pooltool/models/balls/billiard/yellow.blend b/pooltool/models/balls/billiard/yellow.blend
new file mode 100644
index 00000000..c7622984
Binary files /dev/null and b/pooltool/models/balls/billiard/yellow.blend differ
diff --git a/pooltool/models/balls/billiard/yellow.glb b/pooltool/models/balls/billiard/yellow.glb
new file mode 100644
index 00000000..9de835e5
Binary files /dev/null and b/pooltool/models/balls/billiard/yellow.glb differ
diff --git a/pooltool/models/balls/billiard/yellow.png b/pooltool/models/balls/billiard/yellow.png
new file mode 100644
index 00000000..78e4a505
Binary files /dev/null and b/pooltool/models/balls/billiard/yellow.png differ
diff --git a/pooltool/models/balls/billiard/yellow.svg b/pooltool/models/balls/billiard/yellow.svg
new file mode 100644
index 00000000..7df55afd
--- /dev/null
+++ b/pooltool/models/balls/billiard/yellow.svg
@@ -0,0 +1,382 @@
+
+
diff --git a/pooltool/objects/ball/params.py b/pooltool/objects/ball/params.py
index f1b8efe4..94e00f92 100644
--- a/pooltool/objects/ball/params.py
+++ b/pooltool/objects/ball/params.py
@@ -165,12 +165,12 @@ class PrebuiltBallParams(StrEnum):
),
PrebuiltBallParams.BILLIARD_GENERIC: BallParams(
m=0.210,
- R=0.03048,
- u_s=0.5,
+ R=0.0615 / 2,
+ u_s=0.2,
u_r=0.01,
u_sp_proportionality=10 * 2 / 5 / 9,
e_c=0.85,
- f_c=0.5,
+ f_c=0.15,
g=9.81,
),
}
diff --git a/pooltool/objects/table/collection.py b/pooltool/objects/table/collection.py
index 90dbc652..89d83354 100644
--- a/pooltool/objects/table/collection.py
+++ b/pooltool/objects/table/collection.py
@@ -58,20 +58,21 @@ class TableName(StrEnum):
lights_height=1.99,
model_descr=TableModelDescr(name="snooker_generic"),
),
+ # https://web.archive.org/web/20130801042614/http://www.umb.org/Rules/Carom_Rules.pdf
TableName.BILLIARD_WIP: BilliardTableSpecs(
- l=3.05,
- w=3.05 / 2,
+ l=2.84,
+ w=2.84 / 2,
cushion_width=2 * 2.54 / 100,
- cushion_height=0.64 * 2 * 0.028575,
+ cushion_height=0.037,
height=0.708,
lights_height=1.99,
model_descr=TableModelDescr.null(),
),
TableName.SUMTOTHREE_WIP: BilliardTableSpecs(
- l=3.05 / 2.5,
- w=3.05 / 2 / 2.5,
+ l=2.84,
+ w=2.84 / 2,
cushion_width=2 * 2.54 / 100,
- cushion_height=0.64 * 2 * 0.028575,
+ cushion_height=0.037,
height=0.708,
lights_height=1.99,
model_descr=TableModelDescr.null(),
@@ -91,6 +92,7 @@ class TableName(StrEnum):
GameType.SNOOKER: TableName.SNOOKER_GENERIC,
GameType.THREECUSHION: TableName.BILLIARD_WIP,
GameType.SUMTOTHREE: TableName.SUMTOTHREE_WIP,
+ GameType.SANDBOX: TableName.SEVEN_FOOT_SHOWOOD,
}
diff --git a/pooltool/physics/resolve/ball_ball/frictional_inelastic/__init__.py b/pooltool/physics/resolve/ball_ball/frictional_inelastic/__init__.py
index 0a19f6c7..60296b03 100644
--- a/pooltool/physics/resolve/ball_ball/frictional_inelastic/__init__.py
+++ b/pooltool/physics/resolve/ball_ball/frictional_inelastic/__init__.py
@@ -28,28 +28,17 @@ def _resolve_ball_ball(rvw1, rvw2, R, u_b, e_b):
rvw2[1] = ptmath.coordinate_rotation(rvw2[1], -theta)
rvw2[2] = ptmath.coordinate_rotation(rvw2[2], -theta)
- rvw1_f = rvw1.copy()
- rvw2_f = rvw2.copy()
-
# velocity normal component, same for both slip and no-slip after collison cases
v1_n_f = 0.5 * ((1.0 - e_b) * rvw1[1][0] + (1.0 + e_b) * rvw2[1][0])
v2_n_f = 0.5 * ((1.0 + e_b) * rvw1[1][0] + (1.0 - e_b) * rvw2[1][0])
D_v_n_magnitude = abs(v2_n_f - v1_n_f)
- # angular velocity normal component, unchanged
- w1_n_f = rvw1[2][0]
- w2_n_f = rvw2[2][0]
-
- # discard normal components for now
+ # discard velocity normal components for now
# so that surface velocities are tangent
rvw1[1][0] = 0.0
- rvw1[2][0] = 0.0
rvw2[1][0] = 0.0
- rvw2[2][0] = 0.0
- rvw1_f[1][0] = 0.0
- rvw1_f[2][0] = 0.0
- rvw2_f[1][0] = 0.0
- rvw2_f[2][0] = 0.0
+ rvw1_f = rvw1.copy()
+ rvw2_f = rvw2.copy()
v1_c = ptmath.surface_velocity(rvw1, unit_x, R)
v2_c = ptmath.surface_velocity(rvw2, -unit_x, R)
@@ -77,12 +66,11 @@ def _resolve_ball_ball(rvw1, rvw2, R, u_b, e_b):
# then slip condition is invalid so we need to calculate no-slip condition
if not has_relative_velocity or np.dot(v12_c, v12_c_slip) <= 0: # type: ignore
# velocity tangent component for no-slip condition
- D_v1_t = -(1.0 / 9.0) * (
- 2.0 * (rvw1[1] - rvw2[1])
- + R * ptmath.cross(2.0 * rvw1[2] + 7.0 * rvw2[2], unit_x)
+ D_v1_t = -(1.0 / 7.0) * (
+ rvw1[1] - rvw2[1] + R * ptmath.cross(rvw1[2] + rvw2[2], unit_x)
)
- D_w1 = (5.0 / 9.0) * (
- rvw2[2] - rvw1[2] + ptmath.cross(unit_x, rvw2[1] - rvw1[1]) / R
+ D_w1 = -(5.0 / 14.0) * (
+ ptmath.cross(unit_x, rvw1[1] - rvw2[1]) / R + rvw1[2] + rvw2[2]
)
rvw1_f[1] = rvw1[1] + D_v1_t
rvw1_f[2] = rvw1[2] + D_w1
@@ -92,8 +80,6 @@ def _resolve_ball_ball(rvw1, rvw2, R, u_b, e_b):
# reintroduce the final normal components
rvw1_f[1][0] = v1_n_f
rvw2_f[1][0] = v2_n_f
- rvw1_f[2][0] = w1_n_f
- rvw2_f[2][0] = w2_n_f
# rotate everything back to the original frame
rvw1_f[1] = ptmath.coordinate_rotation(rvw1_f[1], theta)
diff --git a/pooltool/ruleset/three_cushion.py b/pooltool/ruleset/three_cushion.py
index 693d075f..41951667 100644
--- a/pooltool/ruleset/three_cushion.py
+++ b/pooltool/ruleset/three_cushion.py
@@ -24,36 +24,46 @@ def _other(cue: str, event: Event) -> str:
raise Exception()
-def is_turn_over(shot: System, constraints: ShotConstraints) -> bool:
- assert constraints.cueable is not None
- cue = constraints.cueable[0]
-
- # Find when the second ball is first hit by the cue-ball
+def is_point(shot: System) -> bool:
+ cue_id = shot.cue.cue_ball_id
- ball_hits = filter_events(
+ # Get collisions of the cue ball with the object balls.
+ cb_ob_collisions = filter_events(
shot.events,
by_type(EventType.BALL_BALL),
- by_ball(cue),
+ by_ball(cue_id),
)
- hits = set()
- for event in ball_hits:
- hits.add(_other(cue, event))
- if len(hits) == 2:
+ hit_ob_ids = set()
+ for event in cb_ob_collisions:
+ hit_ob_ids.add(_other(cue_id, event))
+
+ if len(hit_ob_ids) == 2:
+ # This is the first (and perhaps only) instance of the cue ball hitting the
+ # second object ball.
+ second_ob_collision = event
break
else:
- return True
+ # Both object balls were not contacted by the cue ball. No point.
+ return False
- # Now calculate all cue-ball cushion hits before that event
+ # Both balls have been hit by the object ball. But were at least 3 cushions
+ # contacted before the second object ball was first hit? If yes, point, otherwise
+ # no.
cushion_hits = filter_events(
shot.events,
by_type(EventType.BALL_LINEAR_CUSHION),
- by_ball(cue),
- by_time(event.time, after=False),
+ by_ball(cue_id),
+ by_time(second_ob_collision.time, after=False),
)
- return len(cushion_hits) < 3
+ return len(cushion_hits) >= 3
+
+
+def is_turn_over(shot: System, constraints: ShotConstraints) -> bool:
+ assert constraints.cueable is not None
+ return not is_point(shot)
def is_game_over(
diff --git a/tests/events/example_system.msgpack b/tests/events/example_system.msgpack
new file mode 100644
index 00000000..36ad1082
Binary files /dev/null and b/tests/events/example_system.msgpack differ
diff --git a/tests/events/test_datatypes.py b/tests/events/test_datatypes.py
new file mode 100644
index 00000000..d660f4a2
--- /dev/null
+++ b/tests/events/test_datatypes.py
@@ -0,0 +1,128 @@
+from pathlib import Path
+from typing import List
+
+import pytest
+
+from pooltool.events.datatypes import Event, EventType
+from pooltool.objects.ball.datatypes import Ball
+from pooltool.objects.table.components import (
+ CircularCushionSegment,
+ LinearCushionSegment,
+ Pocket,
+)
+from pooltool.system.datatypes import System
+
+
+@pytest.fixture
+def example_events() -> List[Event]:
+ """
+ Returns the list of Event objects from simulating the example system.
+ """
+ return System.load(Path(__file__).parent / "example_system.msgpack").events
+
+
+def test_get_ball_success(example_events: List[Event]):
+ """
+ Find an event that involves a ball (e.g. BALL_BALL or STICK_BALL)
+ and verify we can retrieve the ball by ID.
+ """
+ # We'll look for a BALL_BALL event that (based on your snippet) should have agents: ('cue', '1')
+ event = next(e for e in example_events if e.event_type == EventType.BALL_BALL)
+
+ # Try retrieving the ball named "cue"
+ cue_ball = event.get_ball("cue", initial=False) # final state by default
+ assert isinstance(cue_ball, Ball)
+ assert cue_ball.id == "cue"
+
+ # Also retrieve the "1" ball by initial state
+ ball_1_initial = event.get_ball("1", initial=True)
+ assert isinstance(ball_1_initial, Ball)
+ assert ball_1_initial.id == "1"
+
+
+def test_get_ball_no_ball_in_event(example_events: List[Event]):
+ """
+ Attempt to retrieve a ball from an event type that doesn't involve a ball, expecting ValueError.
+ """
+ null_event = example_events[0]
+ assert null_event.event_type == EventType.NONE
+
+ with pytest.raises(ValueError, match="does not involve a Ball"):
+ null_event.get_ball("dummy")
+
+
+def test_get_ball_wrong_id(example_events: List[Event]):
+ """
+ Attempt to retrieve a ball using an ID not present in a ball-involving event.
+ """
+ event = next(e for e in example_events if e.event_type == EventType.STICK_BALL)
+
+ with pytest.raises(ValueError, match="No agent of type ball"):
+ event.get_ball("1")
+
+
+def test_get_cushion_success(example_events: List[Event]):
+ """
+ Find a BALL_LINEAR_CUSHION or BALL_CIRCULAR_CUSHION event and verify we can retrieve the cushion.
+ """
+ # Agents: ('cue','6')
+ linear_event = next(
+ e for e in example_events if e.event_type == EventType.BALL_LINEAR_CUSHION
+ )
+
+ cushion_obj = linear_event.get_cushion("6")
+ assert isinstance(cushion_obj, LinearCushionSegment)
+ assert cushion_obj.id == "6"
+
+ # Agents ('cue', '8t')
+ circular_event = next(
+ e for e in example_events if e.event_type == EventType.BALL_CIRCULAR_CUSHION
+ )
+ cushion_obj_circ = circular_event.get_cushion("8t")
+ assert isinstance(cushion_obj_circ, CircularCushionSegment)
+ assert cushion_obj_circ.id == "8t"
+
+
+def test_get_cushion_not_in_event(example_events: List[Event]):
+ """
+ Attempt to retrieve a cushion from an event that doesn't involve one.
+ """
+ event = next(e for e in example_events if e.event_type == EventType.BALL_BALL)
+ with pytest.raises(ValueError, match="does not involve a cushion"):
+ event.get_cushion("8t")
+
+
+def test_get_pocket_success(example_events: List[Event]):
+ """
+ Find a BALL_POCKET event (agents: ('1','rt') in your snippet) and retrieve the pocket.
+ """
+ pocket_event = next(
+ e for e in example_events if e.event_type == EventType.BALL_POCKET
+ )
+ pocket_obj = pocket_event.get_pocket("rt", initial=False)
+ assert isinstance(pocket_obj, Pocket)
+ assert pocket_obj.id == "rt"
+
+
+def test_get_pocket_not_in_event(example_events: List[Event]):
+ """
+ Attempt to retrieve a pocket from a non-pocket event, expecting ValueError.
+ """
+ event = next(e for e in example_events if e.event_type == EventType.BALL_BALL)
+ with pytest.raises(
+ ValueError, match="Event of type ball_ball does not involve a Pocket"
+ ):
+ event.get_pocket("rt")
+
+
+def test_get_pocket_missing_id(example_events: List[Event]):
+ """
+ Attempt to retrieve a pocket with an ID that doesn't match the event's pocket.
+ """
+ pocket_event = next(
+ e for e in example_events if e.event_type == EventType.BALL_POCKET
+ )
+ with pytest.raises(
+ ValueError, match="No agent of type pocket with ID 'non_existent_pocket_id'"
+ ):
+ pocket_event.get_pocket("non_existent_pocket_id")
diff --git a/tests/physics/resolve/ball_ball/test_ball_ball.py b/tests/physics/resolve/ball_ball/test_ball_ball.py
index 2e843916..2623d291 100644
--- a/tests/physics/resolve/ball_ball/test_ball_ball.py
+++ b/tests/physics/resolve/ball_ball/test_ball_ball.py
@@ -1,3 +1,4 @@
+import math
from typing import Tuple
import attrs
@@ -12,27 +13,55 @@
from pooltool.physics.resolve.ball_ball.frictionless_elastic import FrictionlessElastic
-def head_on() -> Tuple[Ball, Ball]:
- cb = Ball.create("cue", xy=(0, 0))
+def velocity_from_speed_and_xy_direction(speed: float, angle_radians: float):
+ """Convert speed and angle to a velocity vector
+
+ Angle is defined CCW from the x-axis in the xy-plane
+ """
+ return speed * np.array([math.cos(angle_radians), math.sin(angle_radians), 0.0])
+
+
+def gearing_z_spin_for_incoming_ball(incoming_ball):
+ """Calculate the amount of sidespin (z-axis spin) required for gearing contact
+ with no relative surface velocity.
+
+ Assumes line of centers from incoming ball to object ball is along the x-axis.
+
+ In order for gearing contact to occur, the sidespin must cancel out any
+ velocity in the tangential (y-axis) direction.
+
+ And from angular velocity equations where
+ 'r' is distance from the rotation center
+ 'w' is angular velocity
+ 'v' is tangential velocity at a distance from the rotation center
+
+ v = w * R
+
+ So, v_y + w_z * R = 0, and therefore w_z = -v_y / R
+ """
+ return -incoming_ball.vel[1] / incoming_ball.params.R
- # Cue ball makes head-on collision with object ball at 1 m/s in +x direction
- cb.state.rvw[1] = np.array([1, 0, 0])
+def ball_collision() -> Tuple[Ball, Ball]:
+ cb = Ball.create("cue", xy=(0, 0))
ob = Ball.create("cue", xy=(2 * cb.params.R, 0))
assert cb.params.m == ob.params.m, "Balls expected to be equal mass"
return cb, ob
-def translating_head_on() -> Tuple[Ball, Ball]:
- cb = Ball.create("cue", xy=(0, 0))
- ob = Ball.create("cue", xy=(2 * cb.params.R, 0))
+def head_on() -> Tuple[Ball, Ball]:
+ cb, ob = ball_collision()
+ # Cue ball makes head-on collision with object ball at 1 m/s in +x direction
+ cb.state.rvw[1] = np.array([1, 0, 0])
+ return cb, ob
+
+def translating_head_on() -> Tuple[Ball, Ball]:
+ cb, ob = ball_collision()
# Cue ball makes head-on collision with object ball at 1 m/s in +x direction
# while both balls move together at 1 m/s in +y direction
cb.state.rvw[1] = np.array([1, 1, 0])
ob.state.rvw[1] = np.array([0, 1, 0])
-
- assert cb.params.m == ob.params.m, "Balls expected to be equal mass"
return cb, ob
@@ -90,7 +119,7 @@ def test_translating_head_on_zero_spin_inelastic(
cb_f, ob_f = model.resolve(cb_i, ob_i, inplace=False)
# Balls should still be moving together in +y direction
- assert np.isclose(cb_f.vel[1], ob_f.vel[1], atol=1e-10)
+ assert abs(cb_f.vel[1] - ob_f.vel[1]) < 1e-10
@pytest.mark.parametrize("model", [FrictionalInelastic(), FrictionalMathavan()])
@@ -111,3 +140,93 @@ def test_head_on_z_spin(model: BallBallCollisionStrategy, cb_wz_i: float):
cb_wz_f = cb_f.state.rvw[2][2]
assert cb_wz_f > 0, "Spin direction shouldn't reverse"
assert cb_wz_f < cb_wz_i, "Spin should be decay"
+
+
+@pytest.mark.parametrize(
+ "model", [FrictionalInelastic(), FrictionalMathavan(num_iterations=int(1e5))]
+)
+@pytest.mark.parametrize("speed", np.logspace(-1, 1, 5))
+@pytest.mark.parametrize(
+ "cut_angle_radians", np.linspace(0, math.pi / 2.0, 8, endpoint=False)
+)
+def test_gearing_z_spin(
+ model: BallBallCollisionStrategy, speed: float, cut_angle_radians: float
+):
+ """Ensure that a gearing collision causes no throw or induced spin.
+
+ A gearing collision is one where the relative surface speed between the balls is 0.
+ In other words, the velocity of each ball at the contact point is identical, and there is no
+ slip at the contact point.
+ """
+
+ unit_x = np.array([1.0, 0.0, 0.0])
+ cb_i, ob_i = ball_collision()
+
+ cb_i.state.rvw[1] = velocity_from_speed_and_xy_direction(speed, cut_angle_radians)
+ cb_i.state.rvw[2][2] = gearing_z_spin_for_incoming_ball(cb_i)
+
+ # sanity check the initial conditions
+ v_c = (
+ ptmath.surface_velocity(
+ cb_i.state.rvw, np.array([1.0, 0.0, 0.0]), cb_i.params.R
+ )
+ - cb_i.vel[0] * unit_x
+ )
+ assert ptmath.norm3d(v_c) < 1e-10, "Relative surface contact speed should be zero"
+
+ cb_f, ob_f = model.resolve(cb_i, ob_i, inplace=False)
+
+ assert (
+ abs(math.atan2(ob_f.vel[1], ob_f.vel[0])) < 1e-3
+ ), "Gearing english shouldn't cause throw"
+ assert abs(ob_f.avel[2]) < 1e-3, "Gearing english shouldn't cause induced side-spin"
+
+
+@pytest.mark.parametrize("model", [FrictionalInelastic()])
+@pytest.mark.parametrize("speed", np.logspace(0, 1, 5))
+@pytest.mark.parametrize(
+ "cut_angle_radians", np.linspace(0, math.pi / 2.0, 8, endpoint=False)
+)
+@pytest.mark.parametrize("relative_surface_speed", np.linspace(0, 0.05, 5))
+def test_low_relative_surface_velocity(
+ model: BallBallCollisionStrategy,
+ speed: float,
+ cut_angle_radians: float,
+ relative_surface_speed: float,
+):
+ """Ensure that collisions with a "small" relative surface velocity end with 0 relative surface velocity.
+ In other words, that the balls are gearing after the collision.
+
+ Note that how small the initial relative surface velocity needs to be for this condition to be met is dependent
+ on model parameters and initial conditions such as ball-ball friction and the collision speed along the line of centers.
+ """
+
+ unit_x = np.array([1.0, 0.0, 0.0])
+ cb_i, ob_i = ball_collision()
+
+ cb_i.state.rvw[1] = velocity_from_speed_and_xy_direction(speed, cut_angle_radians)
+ cb_i.state.rvw[2][2] = gearing_z_spin_for_incoming_ball(cb_i)
+ cb_i.state.rvw[2][2] += (
+ relative_surface_speed / cb_i.params.R
+ ) # from v = w * R -> w = v / R
+
+ # sanity check the initial conditions
+ v_c = (
+ ptmath.surface_velocity(cb_i.state.rvw, unit_x, cb_i.params.R)
+ - cb_i.vel[0] * unit_x
+ )
+ assert (
+ abs(relative_surface_speed - ptmath.norm3d(v_c)) < 1e-10
+ ), f"Relative surface contact speed should be {relative_surface_speed}"
+
+ cb_f, ob_f = model.resolve(cb_i, ob_i, inplace=False)
+
+ cb_v_c_f = ptmath.surface_velocity(
+ cb_f.state.rvw, unit_x, cb_f.params.R
+ ) - np.array([cb_f.vel[0], 0.0, 0.0])
+ ob_v_c_f = ptmath.surface_velocity(
+ ob_f.state.rvw, -unit_x, ob_f.params.R
+ ) - np.array([ob_f.vel[0], 0.0, 0.0])
+ assert (
+ ptmath.norm3d(cb_v_c_f - ob_v_c_f) < 1e-3
+ ), "Final relative contact velocity should be zero"
diff --git a/tests/ruleset/test_shots/01_test_shot_no_point.msgpack b/tests/ruleset/test_shots/01_test_shot_no_point.msgpack
new file mode 100644
index 00000000..915cad11
Binary files /dev/null and b/tests/ruleset/test_shots/01_test_shot_no_point.msgpack differ
diff --git a/tests/ruleset/test_shots/01a_test_shot_no_point.msgpack b/tests/ruleset/test_shots/01a_test_shot_no_point.msgpack
new file mode 100644
index 00000000..0c4aab28
Binary files /dev/null and b/tests/ruleset/test_shots/01a_test_shot_no_point.msgpack differ
diff --git a/tests/ruleset/test_shots/02_test_shot_ispoint.msgpack b/tests/ruleset/test_shots/02_test_shot_ispoint.msgpack
new file mode 100644
index 00000000..0e7ac238
Binary files /dev/null and b/tests/ruleset/test_shots/02_test_shot_ispoint.msgpack differ
diff --git a/tests/ruleset/test_shots/02a_test_shot_ispoint.msgpack b/tests/ruleset/test_shots/02a_test_shot_ispoint.msgpack
new file mode 100644
index 00000000..7f5c3cef
Binary files /dev/null and b/tests/ruleset/test_shots/02a_test_shot_ispoint.msgpack differ
diff --git a/tests/ruleset/test_shots/03_test_shot_ispoint.msgpack b/tests/ruleset/test_shots/03_test_shot_ispoint.msgpack
new file mode 100644
index 00000000..2b005135
Binary files /dev/null and b/tests/ruleset/test_shots/03_test_shot_ispoint.msgpack differ
diff --git a/tests/ruleset/test_shots/03a_test_shot_ispoint.msgpack b/tests/ruleset/test_shots/03a_test_shot_ispoint.msgpack
new file mode 100644
index 00000000..b1cbec3e
Binary files /dev/null and b/tests/ruleset/test_shots/03a_test_shot_ispoint.msgpack differ
diff --git a/tests/ruleset/test_shots/04_test_shot_no_point.msgpack b/tests/ruleset/test_shots/04_test_shot_no_point.msgpack
new file mode 100644
index 00000000..cdc7e03b
Binary files /dev/null and b/tests/ruleset/test_shots/04_test_shot_no_point.msgpack differ
diff --git a/tests/ruleset/test_shots/04a_test_shot_no_point.msgpack b/tests/ruleset/test_shots/04a_test_shot_no_point.msgpack
new file mode 100644
index 00000000..74475770
Binary files /dev/null and b/tests/ruleset/test_shots/04a_test_shot_no_point.msgpack differ
diff --git a/tests/ruleset/test_shots/05_test_shot_ispoint.msgpack b/tests/ruleset/test_shots/05_test_shot_ispoint.msgpack
new file mode 100644
index 00000000..ac886a69
Binary files /dev/null and b/tests/ruleset/test_shots/05_test_shot_ispoint.msgpack differ
diff --git a/tests/ruleset/test_shots/05a_test_shot_ispoint.msgpack b/tests/ruleset/test_shots/05a_test_shot_ispoint.msgpack
new file mode 100644
index 00000000..7e65d65f
Binary files /dev/null and b/tests/ruleset/test_shots/05a_test_shot_ispoint.msgpack differ
diff --git a/tests/ruleset/test_three_cushion.py b/tests/ruleset/test_three_cushion.py
new file mode 100644
index 00000000..e5567343
--- /dev/null
+++ b/tests/ruleset/test_three_cushion.py
@@ -0,0 +1,38 @@
+from pathlib import Path
+
+from pooltool.ruleset.three_cushion import is_point
+from pooltool.system.datatypes import System
+
+root = Path(__file__).parent
+
+
+def test_three_cushion():
+ shot = System.load(root / "test_shots/01_test_shot_no_point.msgpack")
+ assert not is_point(shot)
+
+ shot = System.load(root / "test_shots/01a_test_shot_no_point.msgpack")
+ assert not is_point(shot)
+
+ shot = System.load(root / "test_shots/02_test_shot_ispoint.msgpack")
+ assert is_point(shot)
+
+ shot = System.load(root / "test_shots/02a_test_shot_ispoint.msgpack")
+ assert is_point(shot)
+
+ shot = System.load(root / "test_shots/03_test_shot_ispoint.msgpack")
+ assert is_point(shot)
+
+ shot = System.load(root / "test_shots/03a_test_shot_ispoint.msgpack")
+ assert is_point(shot)
+
+ shot = System.load(root / "test_shots/04_test_shot_no_point.msgpack")
+ assert not is_point(shot)
+
+ shot = System.load(root / "test_shots/04a_test_shot_no_point.msgpack")
+ assert not is_point(shot)
+
+ shot = System.load(root / "test_shots/05_test_shot_ispoint.msgpack")
+ assert is_point(shot)
+
+ shot = System.load(root / "test_shots/05a_test_shot_ispoint.msgpack")
+ assert is_point(shot)