diff --git a/examples/cfg_500_000.json b/examples/cfg_500_000.json new file mode 100644 index 0000000..ec7197b --- /dev/null +++ b/examples/cfg_500_000.json @@ -0,0 +1,14 @@ +{ + "PLATE_WIDTH": 1500, + "PLATE_HEIGHT": 1000, + "CATS_N": 500000, + "CAT_RADIUS": 0.4, + "MOVE_RADIUS": 0.6, + "RADIUS_0": 0.6, + "RADIUS_1": 0.9, + "MOVE_PATTERN_ID": "MOVE_PATTERN_PHIS_ID", + "DISTANCE": "EUCLIDEAN_DISTANCE", + "BORDER_INTERACTION": true, + "FAV_CATS_AMOUNT": 1, + "FAV_CATS_OBSERVING": true +} diff --git a/examples/cfg_beautiful.json b/examples/cfg_beautiful.json new file mode 100644 index 0000000..8b4839b --- /dev/null +++ b/examples/cfg_beautiful.json @@ -0,0 +1,11 @@ +{ + "PLATE_WIDTH": 1500, + "PLATE_HEIGHT": 1000, + "CATS_N": 60, + "CAT_RADIUS": 20, + "MOVE_RADIUS": 40, + "RADIUS_0": 40, + "RADIUS_1": 120, + "MOVE_PATTERN_ID": "MOVE_PATTERN_PHIS_ID", + "DISTANCE": "EUCLIDEAN_DISTANCE" +} diff --git a/examples/cfg_border_magic.json b/examples/cfg_border_magic.json new file mode 100644 index 0000000..9c4792f --- /dev/null +++ b/examples/cfg_border_magic.json @@ -0,0 +1,12 @@ +{ + "PLATE_WIDTH": 1500, + "PLATE_HEIGHT": 1000, + "CATS_N": 10, + "CAT_RADIUS": 50, + "MOVE_RADIUS": 100, + "RADIUS_0": 100, + "RADIUS_1": 300, + "MOVE_PATTERN_ID": "MOVE_PATTERN_PHIS_ID", + "DISTANCE": "EUCLIDEAN_DISTANCE", + "FAV_CATS_OBSERVING": false +} diff --git a/src/catsim/__main__.py b/src/catsim/__main__.py index ab2ee25..eb2d7d1 100644 --- a/src/catsim/__main__.py +++ b/src/catsim/__main__.py @@ -1,66 +1,66 @@ +import argparse +from pathlib import Path + import taichi as ti import taichi.math as tm -import catsim.config as cfg +from catsim.config import Config from catsim.cat import Cat, init_cat_env -from catsim.enums import ( - ALWAYS_VISIBLE, - VISIBLE, -) +from catsim.enums import ALWAYS_VISIBLE, VISIBLE, COLORS, COLORS_IGN, COLORS_FAV from catsim.grid import setup_grid, update_statuses from catsim.spawner import Spawner -POINTS = tm.vec2.field(shape=(cfg.CATS_N,)) -COLORS = ti.field(ti.i32, shape=(cfg.CATS_N,)) -RADIUSES = ti.field(ti.f32, shape=(cfg.CATS_N,)) +POINTS: tm.vec2.field +CAT_COLORS: ti.field +RADII: ti.field -LINES1_BEGIN = tm.vec2.field(shape=(cfg.CATS_N,)) -LINES1_END = tm.vec2.field(shape=(cfg.CATS_N,)) +LINES1_BEGIN: tm.vec2.field +LINES1_END: tm.vec2.field -LINES2_BEGIN = tm.vec2.field(shape=(cfg.CATS_N,)) -LINES2_END = tm.vec2.field(shape=(cfg.CATS_N,)) +LINES2_BEGIN: tm.vec2.field +LINES2_END: tm.vec2.field -LINE_LENGTH = tm.vec2([cfg.RADIUS_1 / cfg.PLATE_WIDTH, cfg.RADIUS_1 / cfg.PLATE_HEIGHT]) -ANGLE_SHIFT = tm.vec2( - [ - tm.asin(cfg.CAT_RADIUS / cfg.PLATE_WIDTH / LINE_LENGTH[0] * 1.5), - tm.asin(cfg.CAT_RADIUS / cfg.PLATE_HEIGHT / LINE_LENGTH[1] * 1.5), - ] -) +LINE_LENGTH: tm.vec2 +ANGLE_SHIFT: tm.vec2 @ti.kernel -def arrange_visuals(cats: ti.template()) -> tuple[ti.i32, ti.i32]: +def arrange_visuals( + cats: ti.template(), + fav_cats_amount: ti.i8, + fav_cats_observing: bool, + cats_n: ti.i32, +) -> tuple[ti.i32, ti.i32]: render_idx: ti.i32 = 0 - for cat_idx in range(cfg.FAV_CATS_AMOUNT, cfg.CATS_N): + for cat_idx in range(fav_cats_amount, cats_n): # skip favorite cats for now to render them later # so they appear "above" others cat = cats[cat_idx] if cat.visibility_status == VISIBLE: POINTS[render_idx] = cat.norm_point - RADIUSES[render_idx] = cat.radius - COLORS[render_idx] = ( - cfg.COLORS[cat.status] - if cat.observed or (not cfg.FAV_CATS_OBSERVING) - else cfg.COLORS_IGN[cat.status] + RADII[render_idx] = cat.radius + CAT_COLORS[render_idx] = ( + COLORS[cat.status] + if cat.observed or (not fav_cats_observing) + else COLORS_IGN[cat.status] ) ti.atomic_add(render_idx, 1) line_idx: ti.i32 = 0 - for cat_idx in range(cfg.FAV_CATS_AMOUNT): + for cat_idx in range(fav_cats_amount): cat = cats[cat_idx] assert cat.visibility_status == ALWAYS_VISIBLE POINTS[render_idx] = cat.norm_point - RADIUSES[render_idx] = cat.radius - COLORS[render_idx] = cfg.COLORS_FAV[cat.status] + RADII[render_idx] = cat.radius + CAT_COLORS[render_idx] = COLORS_FAV[cat.status] ti.atomic_add(render_idx, 1) - if cfg.FAV_CATS_OBSERVING: + if fav_cats_observing: LINES1_BEGIN[line_idx] = cat.norm_point LINES2_BEGIN[line_idx] = cat.norm_point @@ -84,17 +84,19 @@ def arrange_visuals(cats: ti.template()) -> tuple[ti.i32, ti.i32]: @ti.kernel -def move_cats(cats: ti.template()): - for idx in range(cfg.CATS_N): +def move_cats(cats: ti.template(), cats_n: ti.i32): + for idx in range(cats_n): cats[idx].move() -def mainloop(cats: ti.template(), gui: ti.GUI): +def mainloop(cfg: Config, cats: ti.template(), gui: ti.GUI): while gui.running: - move_cats(cats) + move_cats(cats, cfg.CATS_N) update_statuses(cats) - count_cats, count_lines = arrange_visuals(cats) + count_cats, count_lines = arrange_visuals( + cats, cfg.FAV_CATS_AMOUNT, cfg.FAV_CATS_OBSERVING, cfg.CATS_N + ) if count_lines != 0: gui.lines( @@ -111,14 +113,14 @@ def mainloop(cats: ti.template(), gui: ti.GUI): if count_cats != 0: gui.circles( pos=POINTS.to_numpy()[:count_cats], - radius=RADIUSES.to_numpy()[:count_cats], - color=COLORS.to_numpy()[:count_cats], + radius=RADII.to_numpy()[:count_cats], + color=CAT_COLORS.to_numpy()[:count_cats], ) gui.show() -def validate_config(): +def validate_config(cfg): if cfg.PLATE_HEIGHT <= 0 or cfg.PLATE_WIDTH <= 0: raise ValueError("Plate height/width must be > 0") @@ -143,8 +145,43 @@ def validate_config(): raise ValueError("Radius 1 must be > Radius 0") +def init_env(cfg: Config): + global POINTS, CAT_COLORS, RADII + POINTS = tm.vec2.field(shape=(cfg.CATS_N,)) + CAT_COLORS = ti.field(ti.i32, shape=(cfg.CATS_N,)) + RADII = ti.field(ti.f32, shape=(cfg.CATS_N,)) + + global LINES1_BEGIN, LINES1_END, LINES2_BEGIN, LINES2_END + LINES1_BEGIN = tm.vec2.field(shape=(cfg.CATS_N,)) + LINES1_END = tm.vec2.field(shape=(cfg.CATS_N,)) + + LINES2_BEGIN = tm.vec2.field(shape=(cfg.CATS_N,)) + LINES2_END = tm.vec2.field(shape=(cfg.CATS_N,)) + + global LINE_LENGTH, ANGLE_SHIFT + LINE_LENGTH = tm.vec2( + [cfg.RADIUS_1 / cfg.PLATE_WIDTH, cfg.RADIUS_1 / cfg.PLATE_HEIGHT] + ) + ANGLE_SHIFT = tm.vec2( + [ + tm.asin(cfg.CAT_RADIUS / cfg.PLATE_WIDTH / LINE_LENGTH[0] * 1.5), + tm.asin(cfg.CAT_RADIUS / cfg.PLATE_HEIGHT / LINE_LENGTH[1] * 1.5), + ] + ) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("config_file", type=str) + return parser.parse_args() + + def main(): - validate_config() + args = parse_arguments() + cfg = Config.generate_from_json(Path(args.config_file)) + validate_config(cfg) + + init_env(cfg) init_cat_env( move_radius=cfg.MOVE_RADIUS, @@ -156,6 +193,8 @@ def main(): prob_inter=cfg.PROB_INTERACTION, border_inter=cfg.BORDER_INTERACTION, distance_type=cfg.DISTANCE, + fav_cats_observing=cfg.FAV_CATS_OBSERVING, + observable_angle_span=cfg.OBSERVABLE_ANGLE_SPAN, ) setup_grid( @@ -163,6 +202,8 @@ def main(): r1=cfg.RADIUS_1, width=cfg.PLATE_WIDTH, height=cfg.PLATE_HEIGHT, + fav_cats_amount=cfg.FAV_CATS_AMOUNT, + fav_cats_log=cfg.FAV_CATS_LOGGING, ) cats = Cat.field(shape=(cfg.CATS_N,)) @@ -181,7 +222,7 @@ def main(): for idx in range(cfg.FAV_CATS_AMOUNT): spawner.set_always_visible_cat(idx, 0) - mainloop(cats, gui) + mainloop(cfg, cats, gui) if __name__ == "__main__": diff --git a/src/catsim/cat.py b/src/catsim/cat.py index 12ba0fd..e09b23c 100644 --- a/src/catsim/cat.py +++ b/src/catsim/cat.py @@ -1,7 +1,6 @@ import taichi as ti import taichi.math as tm -from catsim.config import FAV_CATS_OBSERVING, OBSERVABLE_ANGLE_SPAN from catsim.enums import ( ALWAYS_VISIBLE, INTERACTION_LEVEL_0, @@ -35,6 +34,9 @@ _PROB_INTER: bool _BORDER_INTER: bool +_FAV_CATS_OBSERVING: bool +_OBSERVABLE_ANGLE_SPAN: float + def init_cat_env( move_radius: ti.f32, @@ -46,6 +48,8 @@ def init_cat_env( prob_inter: bool, distance_type: ti.i32, border_inter: bool, + fav_cats_observing: bool, + observable_angle_span: float, ): global \ _RADIUS_0, \ @@ -57,7 +61,9 @@ def init_cat_env( _PROB_INTER, \ _BORDER_INTER, \ _DISTANCE_TYPE, \ - _MAX_DISTANCE + _MAX_DISTANCE, \ + _FAV_CATS_OBSERVING, \ + _OBSERVABLE_ANGLE_SPAN _MOVE_RADIUS = move_radius _RADIUS_0 = r0 @@ -68,6 +74,8 @@ def init_cat_env( _PROB_INTER = prob_inter _BORDER_INTER = border_inter _DISTANCE_TYPE = distance_type + _FAV_CATS_OBSERVING = fav_cats_observing + _OBSERVABLE_ANGLE_SPAN = observable_angle_span p0 = tm.vec2([0, 0]) p1 = tm.vec2([0, _PLATE_WIDTH]) @@ -204,14 +212,16 @@ def move(self): direction_angle = tm.atan2( self.point[1] - self.prev_point[1], self.point[0] - self.prev_point[0] ) - self.observable_angle[0] = direction_angle - OBSERVABLE_ANGLE_SPAN - self.observable_angle[1] = direction_angle + OBSERVABLE_ANGLE_SPAN + self.observable_angle[0] = direction_angle - _OBSERVABLE_ANGLE_SPAN + self.observable_angle[1] = direction_angle + _OBSERVABLE_ANGLE_SPAN if _BORDER_INTER: self.update_visibility_status() @ti.func - def fight_with(self, other_cat: ti.template()): + def fight_with(self, other_cat: ti.template()) -> ti.i32: + _st = INTERACTION_NO + if ( other_cat.visibility_status == VISIBLE or other_cat.visibility_status == ALWAYS_VISIBLE @@ -220,10 +230,12 @@ def fight_with(self, other_cat: ti.template()): if dist > _RADIUS_1 or (_PROB_INTER and ti.random() >= 1.0 / (dist * dist)): self.status = ti.max(self.status, INTERACTION_NO) + _st = INTERACTION_NO + pass else: observing = True - if self.visibility_status == ALWAYS_VISIBLE and FAV_CATS_OBSERVING: + if self.visibility_status == ALWAYS_VISIBLE and _FAV_CATS_OBSERVING: relative_angle = tm.atan2( other_cat.point[1] - self.point[1], other_cat.point[0] - self.point[0], @@ -238,8 +250,10 @@ def fight_with(self, other_cat: ti.template()): observing = False if observing: - self.status = ( - INTERACTION_LEVEL_0 - if dist <= _RADIUS_0 - else ti.max(self.status, INTERACTION_LEVEL_1) - ) + if dist <= _RADIUS_0: + self.status = INTERACTION_LEVEL_0 + _st = INTERACTION_LEVEL_0 + else: + self.status = ti.max(self.status, INTERACTION_LEVEL_1) + _st = INTERACTION_LEVEL_1 + return _st diff --git a/src/catsim/config.json b/src/catsim/config.json new file mode 100644 index 0000000..905a401 --- /dev/null +++ b/src/catsim/config.json @@ -0,0 +1,13 @@ +{ + "PLATE_WIDTH": 1500, + "PLATE_HEIGHT": 1000, + "CATS_N": 100, + "CAT_RADIUS": 20, + "MOVE_RADIUS": 40, + "RADIUS_0": 40, + "RADIUS_1": 120, + "MOVE_PATTERN_ID": "MOVE_PATTERN_PHIS", + "DISTANCE": "EUCLIDEAN_DISTANCE", + "FAV_CATS_AMOUNT": 1, + "FAV_CATS_OBSERVING": true +} diff --git a/src/catsim/config.py b/src/catsim/config.py index 20c06ff..3e55945 100644 --- a/src/catsim/config.py +++ b/src/catsim/config.py @@ -1,56 +1,67 @@ -import taichi as ti +from __future__ import annotations -from catsim.enums import ( - EUCLIDEAN_DISTANCE, - INTERACTION_LEVEL_0, - INTERACTION_LEVEL_1, - INTERACTION_NO, - MOVE_PATTERN_PHIS, -) +import importlib +import json +from dataclasses import dataclass +from pathlib import Path +import taichi as ti -# ----- GENERAL ----- # -PLATE_WIDTH, PLATE_HEIGHT = 1500, 1000 -CATS_N = 150 +enums = importlib.import_module("enums") -CAT_RADIUS = 0.02 * PLATE_HEIGHT -MOVE_RADIUS = CAT_RADIUS * 2 -RADIUS_0 = CAT_RADIUS * 2 -RADIUS_1 = RADIUS_0 * 3 +@dataclass +class Config: + # ----- GENERAL ----- + PLATE_WIDTH: int = 1500 + PLATE_HEIGHT: int = 1000 + CATS_N: int = 150 -# ----- FAVORITE CATS ----- # -# 0 <= amount <= CATS_N -FAV_CATS_AMOUNT = 2 -FAV_CATS_OBSERVING = True + # ----- CAT ----- + CAT_RADIUS: float = 0.02 * PLATE_HEIGHT + MOVE_RADIUS: float = CAT_RADIUS * 2 + RADIUS_0: float = CAT_RADIUS * 2 + RADIUS_1: float = RADIUS_0 * 3 -# pi / 8 <= angle <= pi / 2 -OBSERVABLE_ANGLE_SPAN = ti.math.pi / 4 + # ----- PATTERNS ----- + MOVE_PATTERN_ID: int = enums.MOVE_PATTERN_PHIS + DISTANCE: int = enums.EUCLIDEAN_DISTANCE + # ----- INTERACTIONS ----- + PROB_INTERACTION: bool = False + BORDER_INTERACTION: bool = True -# ----- PATTERNS ----- # -MOVE_PATTERN_ID = MOVE_PATTERN_PHIS -DISTANCE = EUCLIDEAN_DISTANCE + # ----- FAVORITE CATS ----- # + # 0 <= amount <= CATS_N + FAV_CATS_AMOUNT: int = 1 + FAV_CATS_OBSERVING: bool = True + FAV_CATS_LOGGING: bool = True + # pi / 8 <= angle <= pi / 2 + OBSERVABLE_ANGLE_SPAN: float = ti.math.pi / 4 -# ----- INTERACTIONS ----- # -PROB_INTERACTION = False -BORDER_INTERACTION = True + # ----- VISUALISATION ----- # + LINES_RADIUS: ti.i32 = CAT_RADIUS // 10 + @staticmethod + def generate_from_json(json_name: Path) -> Config: + data: dict + with open(json_name, "r") as file: + data = json.load(file) -# ----- VISUALISATION ----- # -LINES_RADIUS: ti.i32 = CAT_RADIUS // 10 + _cfg_data = {} -COLORS = ti.field(ti.i32, shape=(3,)) -COLORS[INTERACTION_NO] = 0x34C924 # green -COLORS[INTERACTION_LEVEL_1] = 0xFFFF00 # yellow -COLORS[INTERACTION_LEVEL_0] = 0xED553B # red + keys_with_const = {"MOVE_PATTERN_ID", "DISTANCE"} -COLORS_FAV = ti.field(ti.i32, shape=(3,)) -COLORS_FAV[INTERACTION_NO] = 0x45B39D # light green -COLORS_FAV[INTERACTION_LEVEL_1] = 0xFAD7A0 # light yellow -COLORS_FAV[INTERACTION_LEVEL_0] = 0xF1948A # light red + for key, value in data.items(): + if key in keys_with_const: + try: + _cfg_data[key] = getattr(enums, value) + except AttributeError: + print( + f"WARNING: Attribute '{value}' not found in 'const'. " + f"Therefore the default value for field {key} was used." + ) + else: + _cfg_data[key] = value -COLORS_IGN = ti.field(ti.i32, shape=(3,)) -COLORS_IGN[INTERACTION_NO] = 0x333333 -COLORS_IGN[INTERACTION_LEVEL_1] = 0x545454 -COLORS_IGN[INTERACTION_LEVEL_0] = 0x787878 + return Config(**_cfg_data) diff --git a/src/catsim/enums.py b/src/catsim/enums.py index c8be2a0..9a91623 100644 --- a/src/catsim/enums.py +++ b/src/catsim/enums.py @@ -11,9 +11,9 @@ CHEBYSHEV_DISTANCE = 2 # ----- INTERACTION LEVELS ----- # -INTERACTION_NO = 0 -INTERACTION_LEVEL_1 = 1 -INTERACTION_LEVEL_0 = 2 +INTERACTION_NO: ti.i8 = 0 +INTERACTION_LEVEL_1: ti.i8 = 1 +INTERACTION_LEVEL_0: ti.i8 = 2 # ----- VISIBILITY STATUSES ----- # NEVER_APPEARED: ti.i8 = 0 @@ -21,3 +21,19 @@ INVISIBLE: ti.i8 = 2 ALWAYS_VISIBLE: ti.i8 = 3 ALWAYS_INVISIBLE: ti.i8 = 4 + +# ----- VISUALISATION ----- # +COLORS = ti.field(ti.i32, shape=(3,)) +COLORS[INTERACTION_NO] = 0x34C924 # green +COLORS[INTERACTION_LEVEL_1] = 0xFFFF00 # yellow +COLORS[INTERACTION_LEVEL_0] = 0xED553B # red + +COLORS_FAV = ti.field(ti.i32, shape=(3,)) +COLORS_FAV[INTERACTION_NO] = 0x45B39D # light green +COLORS_FAV[INTERACTION_LEVEL_1] = 0xFAD7A0 # light yellow +COLORS_FAV[INTERACTION_LEVEL_0] = 0xF1948A # light red + +COLORS_IGN = ti.field(ti.i32, shape=(3,)) +COLORS_IGN[INTERACTION_NO] = 0x333333 +COLORS_IGN[INTERACTION_LEVEL_1] = 0x545454 +COLORS_IGN[INTERACTION_LEVEL_0] = 0x787878 diff --git a/src/catsim/grid.py b/src/catsim/grid.py index 61dba3e..34871d6 100644 --- a/src/catsim/grid.py +++ b/src/catsim/grid.py @@ -1,5 +1,6 @@ import math from typing import Any +from catsim.enums import INTERACTION_NO, INTERACTION_LEVEL_0 import taichi as ti @@ -19,6 +20,12 @@ _GRID_COL_N: ti.i32 _GRID_ROW_N: ti.i32 +_FAV_CATS_AMOUNT: ti.i32 +_FAV_CATS_LOGGING: bool +_NEW_LOGS: ti.field +_OLD_LOGS: ti.field +_OLD_STATUS: ti.field + """ contains Cats ids: - size := cats_n @@ -46,13 +53,29 @@ _F_CAT_PER_CELL: Any -def setup_grid(cat_n: ti.i32, r1: ti.i32, width: ti.i32, height: ti.i32): +def setup_grid( + cat_n: ti.i32, + r1: ti.i32, + width: ti.i32, + height: ti.i32, + fav_cats_amount: ti.i32, + fav_cats_log: bool, +): global _CATS_N, _RADIUS_1, _PLATE_WIDTH, _PLATE_HEIGHT _CATS_N = cat_n _RADIUS_1 = r1 _PLATE_WIDTH = width _PLATE_HEIGHT = height + global _FAV_CATS_AMOUNT, _FAV_CATS_LOGGING, _NEW_LOGS, _OLD_LOGS, _OLD_STATUS + _FAV_CATS_LOGGING = fav_cats_log + _FAV_CATS_AMOUNT = fav_cats_amount + _NEW_LOGS = ti.field(ti.i32, shape=(fav_cats_amount, _CATS_N)) + _OLD_LOGS = ti.field(ti.i32, shape=(fav_cats_amount, _CATS_N)) + _OLD_STATUS = ti.field(ti.i8, shape=(fav_cats_amount,)) + if _FAV_CATS_LOGGING: + _OLD_STATUS.fill(1) + global _CELL_N, _GRID_COL_N, _GRID_ROW_N, _CELL_SZ _CELL_SZ = _RADIUS_1 _GRID_COL_N = math.ceil(_PLATE_WIDTH / _CELL_SZ) @@ -116,6 +139,9 @@ def init_cell_storage(cats: ti.template()): def update_statuses(cats: ti.template()): init_cell_storage(cats) + if _FAV_CATS_LOGGING: + _NEW_LOGS.fill(0) + ti.loop_config(serialize=True) for idx1 in range(_CATS_N): cell_idx = ti.floor(cats[idx1].point / _CELL_SZ, ti.i32) @@ -135,4 +161,30 @@ def update_statuses(cats: ti.template()): idx2 = _F_CELL_STORAGE[_idx2] if idx1 != idx2: - cats[idx1].fight_with(cats[idx2]) + status = cats[idx1].fight_with(cats[idx2]) + + if _FAV_CATS_LOGGING and idx1 < _FAV_CATS_AMOUNT: + _NEW_LOGS[idx1, idx2] = status + + if _FAV_CATS_LOGGING: + for i in range(_FAV_CATS_AMOUNT): + has_interaction = False + for j in range(_CATS_N): + new_val = _NEW_LOGS[i, j] + old_val = _OLD_LOGS[i, j] + if new_val != INTERACTION_NO: + has_interaction = True + if new_val != old_val: + if new_val == INTERACTION_LEVEL_0: + print(f"CAT {j} IS FIGHTING WITH YOUR CAT {i}") + else: + print(f"CAT {j} HISSED AT YOUR CAT {i}") + + _OLD_LOGS[i, j] = new_val + + if not has_interaction: + if _OLD_STATUS[i] == 1: + print(f"YOUR CAT NUMBER {i} IS NOT HURT BY ANYONE") + _OLD_STATUS[i] = 0 + else: + _OLD_STATUS[i] = 1 diff --git a/tests/test_grid.py b/tests/test_grid.py index 4793a49..8dab899 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -33,6 +33,8 @@ def test_concrete_case(self): prob_inter=False, border_inter=False, distance_type=EUCLIDEAN_DISTANCE, + fav_cats_observing=False, + observable_angle_span=ti.math.pi / 4, ) points = ti.Vector.field(n=2, dtype=float, shape=(N,)) @@ -50,7 +52,7 @@ def test_concrete_case(self): points=points, ) - setup_grid(N, float(R1), float(WIDTH), float(HEIGHT)) + setup_grid(N, float(R1), float(WIDTH), float(HEIGHT), 1, False) update_statuses(cats) expected_statuses = [ @@ -76,7 +78,14 @@ def test_concrete_case(self): ], ) def test_primitive_func(self, N, R0, R1, RADIUS, WIDTH, HEIGHT, distance_type): - setup_grid(cat_n=N, r1=R1, width=WIDTH, height=HEIGHT) + setup_grid( + cat_n=N, + r1=R1, + width=WIDTH, + height=HEIGHT, + fav_cats_amount=1, + fav_cats_log=False, + ) init_cat_env( move_radius=R0, @@ -88,6 +97,8 @@ def test_primitive_func(self, N, R0, R1, RADIUS, WIDTH, HEIGHT, distance_type): prob_inter=False, border_inter=False, distance_type=distance_type, + fav_cats_observing=False, + observable_angle_span=ti.math.pi / 4, ) cats = Cat.field(shape=(N,)) @@ -100,7 +111,7 @@ def test_primitive_func(self, N, R0, R1, RADIUS, WIDTH, HEIGHT, distance_type): ) spawner.set_cat_init_positions(N, 0) - expected_statuses = ti.ndarray(dtype=ti.i32, shape=(N,)) + expected_statuses = ti.ndarray(dtype=ti.i8, shape=(N,)) primitive_update_states(N, cats, expected_statuses, distance_type, R0, R1) update_statuses(cats) diff --git a/tests/test_move.py b/tests/test_move.py index f6e84c4..4f96ff5 100644 --- a/tests/test_move.py +++ b/tests/test_move.py @@ -26,12 +26,7 @@ def move_line( @pytest.mark.parametrize( "x, y, move_radius", - [ - (10, 50, 8), - (42, 17, 15), - (59, 10, 45), - (1, 2, 10), - ], + [(10, 50, 8), (42, 17, 15), (59, 10, 45), (1, 2, 10)], ) def test_move_radius(self, x: int, y: int, move_radius: int): WIDTH, HEIGHT = 1000, 1000