diff --git a/example.ipynb b/example.ipynb index 843aa82..671bb4d 100644 --- a/example.ipynb +++ b/example.ipynb @@ -13,8 +13,8 @@ "model = Model()\n", "\n", "# SPF Calculations\n", - "with model.root.add_layer() as layer:\n", - " layer.add_schema({\n", + "with model.cursor.add_layer():\n", + " model.cursor.add_schema({\n", " \"label\": \"schema1\",\n", " \"params\": [\"x\"],\n", " \"actions\": [\"get\"],\n", @@ -26,8 +26,8 @@ " })\n", "\n", " # SPF 1\n", - " with layer.add_sequence() as seq:\n", - " seq.add_schema({\n", + " with model.cursor.add_sequence():\n", + " model.cursor.add_schema({\n", " \"label\": \"schema2\",\n", " \"params\": [\"x\"],\n", " \"actions\": [\"get\"],\n", @@ -37,13 +37,13 @@ " 2: {\"a\": 20, \"b\": 20, \"c\": 20, \"d\": 20}\n", " }\n", " }, hidden=True)\n", - " @seq.add_wrapped()\n", + " @model.cursor.add_wrapped(tags='af')\n", " def process(a, b, c, d):\n", " return a * b + c / d\n", " \n", " # SPF 2\n", - " with layer.add_sequence() as seq:\n", - " seq.add_schema({\n", + " with model.cursor.add_sequence():\n", + " model.cursor.add_schema({\n", " \"label\": \"schema3\",\n", " \"params\": [\"x\"],\n", " \"actions\": [\"get\"],\n", @@ -53,276 +53,135 @@ " 2: {\"a\": 20, \"b\": 20, \"c\": 20, \"d\": 20}\n", " }\n", " }, hidden=True)\n", - " @seq.add_wrapped()\n", + " @model.cursor.add_wrapped(tags='af')\n", " def process(a, b, c, d):\n", " return a * b + c / d\n", "\n", "# AF Calculations\n", - "with model.root.add_layer() as layer:\n", + "with model.cursor.add_layer() as layer:\n", "\n", " # AF 1\n", - " layer.add_wrapped()\n", + " @model.cursor.add_wrapped()\n", " def af_lane_width(lane_width):\n", " return lane_width / 12" ] }, { "cell_type": "code", - "execution_count": 2, - "id": "b5f1c041", + "execution_count": 3, + "id": "19660b10", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'x': 1, 'schema1': 10, 'process': 101.0}" + "{'af': [,\n", + " ]}" ] }, - "execution_count": 2, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "model.analyze(x=1)" + "model.tagged" ] }, { "cell_type": "code", - "execution_count": null, - "id": "cb22780f", + "execution_count": 8, + "id": "3ad8af3b", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "['lane_width', 'x']" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "model.root.hidden" + "model.params" ] }, { "cell_type": "code", - "execution_count": null, - "id": "5b5a3a05", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3640eff1-d205-46e9-97e2-8e10a19c1372", + "execution_count": 10, + "id": "b5f1c041", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'x': 1,\n", + " 'lane_width': 13,\n", + " 'schema1': 10,\n", + " 'process': 101.0,\n", + " 'af_lane_width': 1.0833333333333333}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "import importlib\n", - "import modelsandbox\n", - "importlib.reload(modelsandbox)\n", - "from modelsandbox import Model" + "model.analyze(x=1, lane_width=13)" ] }, { "cell_type": "code", - "execution_count": null, - "id": "9e8343d5-7941-467c-84a7-9a52b2f7a895", + "execution_count": 5, + "id": "cb22780f", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "['a']" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "from modelsandbox import Model\n", - "\n", - "model = Model()\n", + "def params(f):\n", + " num_args = f.__code__.co_argcount\n", + " return list(f.__code__.co_varnames[:num_args])\n", + "def f(a, *b, **c):\n", + " return\n", "\n", - "model[0].set_label('First Layer')\n", - "model[0].set_tags('layer1')\n", - "model[0].add_function(lambda a, b: a + b, tags='tag1')\n", - "model[0].add_sequence()\n", - "model[0][1]" - ] - }, - { - "cell_type": "markdown", - "id": "f9b134bd-bfb8-46ca-88b1-d3584a512ca6", - "metadata": {}, - "source": [ - "# Cost example" + "params(f)" ] }, { "cell_type": "code", - "execution_count": null, - "id": "0bdd3247-fad7-411d-9806-239895299f6b", + "execution_count": 14, + "id": "5b5a3a05", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'b'" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "import importlib\n", - "import modelsandbox\n", - "importlib.reload(modelsandbox)\n", - "from modelsandbox import Model\n", + "import inspect\n", "\n", - "# Initialize the model class\n", - "model = Model()\n", - "\n", - "# Add a layer to the model to compute airline ticket cost\n", - "model[0].set_label('Compute ticket cost')\n", - "\n", - "# Add the process schema to the model\n", - "model[0].add_schema({\n", - " \"label\": \"ticket_cost\",\n", - " \"params\": [\"destination\", \"airline_class\"],\n", - " \"actions\": [\"get\", \"get\"],\n", - " \"data\": {\n", - " \"Chicago\": {\n", - " \"Economy\": 220,\n", - " \"Business\": 450,\n", - " \"First\": 785\n", - " },\n", - " \"Los Angeles\": {\n", - " \"Economy\": 365,\n", - " \"Business\": 520,\n", - " \"First\": 965\n", - " }\n", - " }\n", - "})\n", - "\n", - "# Add a layer to the model to compute additional costs\n", - "model[1].set_label('Compute expenses')\n", - "\n", - "# Add a processor to compute travel cost\n", - "@model[1].add_wrapped(tags=['__cost'])\n", - "def travel_cost(number_of_travelers, ticket_cost):\n", - " \"\"\"\n", - " Compute total travel cost for all travelers.\n", - " \"\"\"\n", - " return number_of_travelers * ticket_cost\n", - "\n", - "# Add processor to compute lodging cost\n", - "@model[1].add_wrapped(tags=['__cost'])\n", - "def lodging_cost(number_of_travelers, nightly_cost, number_of_nights):\n", - " \"\"\"\n", - " Compute total lodging cost for all travelers.\n", - " \"\"\"\n", - " return number_of_travelers * nightly_cost * number_of_nights\n", - "\n", - "# Add processor to compute per diem\n", - "@model[1].add_wrapped(tags=['__cost'])\n", - "def per_diem_cost(number_of_travelers, number_of_nights, per_diem):\n", - " \"\"\"\n", - " Compute total per diem cost for all travelers.\n", - " \"\"\"\n", - " return number_of_travelers * number_of_nights * per_diem\n", - "\n", - "# Add a layer to the model to aggregate costs\n", - "model[2].set_label('Aggregate expenses')\n", - "\n", - "# Add processor to compute total trip cost\n", - "@model[2].add_wrapped()\n", - "def total_trip_cost(travel_cost, lodging_cost, per_diem_cost, __cost):\n", - " \"\"\"\n", - " Compute total trip cost for all travelers.\n", - " \"\"\"\n", - " return travel_cost + lodging_cost + per_diem_cost\n", - "\n", - "model.structure" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1eba3535-f095-42be-8a4c-959b0ebe38bd", - "metadata": {}, - "outputs": [], - "source": [ - "model.tagged" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8e95c00e-ca1a-4c2c-b630-d7ae32dd5fe7", - "metadata": {}, - "outputs": [], - "source": [ - "model.analyze(\n", - " airline_class=\"Business\",\n", - " destination=\"Chicago\",\n", - " nightly_cost=185,\n", - " number_of_nights=4,\n", - " number_of_travelers=3,\n", - " per_diem=72,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "676a253a-10db-4fe2-9634-1c531daa15c4", - "metadata": {}, - "source": [ - "# Scratch" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1c829f1d-5419-4b1a-b42d-bf98f85129fd", - "metadata": {}, - "outputs": [], - "source": [ - "from example import model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b1c7f912-bb76-4966-998f-ec3979efaf1f", - "metadata": {}, - "outputs": [], - "source": [ - "model.analyze(\n", - " airline_class=\"Business\",\n", - " destination=\"Chicago\",\n", - " nightly_cost=185,\n", - " number_of_nights=4,\n", - " number_of_travelers=3,\n", - " per_diem=72,\n", - " ticket_cost=500\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "740b1f87-61a9-4406-8698-d841f7eaa011", - "metadata": {}, - "outputs": [], - "source": [ - "model.parameters" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8288bbf1-9a90-4004-8f95-d4557b1fd9bf", - "metadata": {}, - "outputs": [], - "source": [ - "from sample_model import model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "655db299-aacf-4933-84c5-4e82f0e3ae67", - "metadata": {}, - "outputs": [], - "source": [ - "model.analyze(\n", - " aadt = 580,\n", - " curve_length = 0.2,\n", - " curve_radius = 800,\n", - " lane_width = 11,\n", - " length = 0.6,\n", - " observed_kabco = 3.4,\n", - " shoulder_type = 'gravel',\n", - " shoulder_width = 4,\n", - " spiral = 'both'\n", - ")" + "inspect.getfullargspec(f)[1]" ] } ], diff --git a/hsm/rtl_seg/model.py b/hsm/rtl_seg/model.py index db711ac..6ae3301 100644 --- a/hsm/rtl_seg/model.py +++ b/hsm/rtl_seg/model.py @@ -15,6 +15,7 @@ # Initialize the model class model = Model() +cur = model.cursor # Add data validation model.add_validation(os.path.join(SCHEMA_PATH, 'validation.json')) @@ -28,13 +29,13 @@ # Layer #1: SPF calculations # ----------------------------------------------------------------------------- -with model.root.add_layer() as layer: - with layer.add_sequence() as seq: +with cur.add_layer(): + with cur.add_sequence(): # Add SPF parameters - seq.add_schema(os.path.join(SCHEMA_PATH, 'spf.json'), hidden=True) + cur.add_schema(os.path.join(SCHEMA_PATH, 'spf.json'), hidden=True) # Compute number of crashes - @seq.add_wrapped() + @cur.add_wrapped() def n_kabco(aadt, length): """ Based on HSM Equation 10-7. @@ -44,7 +45,7 @@ def n_kabco(aadt, length): return n # Compute overdispersion - @seq.add_wrapped() + @cur.add_wrapped() def overdispersion(length): # Compute overdispersion k = 0.236 / length @@ -55,9 +56,9 @@ def overdispersion(length): # Layer #2: AF calculations # ----------------------------------------------------------------------------- -with model.root.add_layer() as layer: +with cur.add_layer() as layer: - @layer.add_wrapped() + @cur.add_wrapped() def af_lane_width(lane_width, aadt): # Compute adjustment factor if lane_width < 10: @@ -70,7 +71,7 @@ def af_lane_width(lane_width, aadt): af = 1 return af - @layer.add_wrapped() + @cur.add_wrapped() def af_shoulder_width(shoulder_width, aadt): # Compute adjustment factor if shoulder_width < 2: @@ -85,9 +86,9 @@ def af_shoulder_width(shoulder_width, aadt): af = np.clip(0.98 + 0.6875e-4 * (aadt - 400), 0.98, 0.87) return af - model.add_schema(os.path.join(SCHEMA_PATH, 'af_shoulder_type.json')) + cur.add_schema(os.path.join(SCHEMA_PATH, 'af_shoulder_type.json')) - @layer.add_wrapped() + @cur.add_wrapped() def af_horizontal_curve(curve_length, curve_radius, spiral): # Check if provided if (curve_length == 0) or (curve_radius == 0): @@ -107,7 +108,7 @@ def af_horizontal_curve(curve_length, curve_radius, spiral): ) / (1.55 * curve_length) return af - @layer.add_wrapped() + @cur.add_wrapped() def af_se_variance(se_variance): """ Based on HSM equation 10-14, 10-15, 10-16. @@ -124,7 +125,7 @@ def af_se_variance(se_variance): af = 1.00 + (6 * (se_variance - 0.01)) return af - @layer.add_wrapped() + @cur.add_wrapped() def af_grade(grade): """ Based on HSM table 10-11. @@ -142,7 +143,7 @@ def af_grade(grade): af = 1.16 return af - @layer.add_wrapped() + @cur.add_wrapped() def af_driveway_density(aadt, length, driveway_density): """ Based on HSM equation 10-17 @@ -155,7 +156,7 @@ def af_driveway_density(aadt, length, driveway_density): (0.322 + (5 * (0.05 - 0.005 * math.log(aadt)))) return af - @layer.add_wrapped() + @cur.add_wrapped() def af_rumble_cl(rumble_cl): """ Based on HSM page 10-29 @@ -167,7 +168,7 @@ def af_rumble_cl(rumble_cl): af = 1.00 return af - @layer.add_wrapped() + @cur.add_wrapped() def af_passing_lanes(passing_lanes): """ Based on HSM page 10-29 @@ -181,7 +182,7 @@ def af_passing_lanes(passing_lanes): af = 0.65 return af - @layer.add_wrapped() + @cur.add_wrapped() def af_twltl(twltl, length, driveway_density): """ Based on HSM equation 10-18 and 10-19. @@ -197,7 +198,7 @@ def af_twltl(twltl, length, driveway_density): return af - @layer.add_wrapped() + @cur.add_wrapped() def af_rhr(rhr): """ Based on HSM equation 10-20 and the roadside hazard rating in appendix 13A @@ -206,7 +207,7 @@ def af_rhr(rhr): af = math.exp(-0.6869 + (0.0668 * rhr)) / math.exp(-0.4865) return af - @layer.add_wrapped() + @cur.add_wrapped() def af_lighting(lighting): """ Based on HSM equation 10-21 and table 10-12. @@ -218,7 +219,7 @@ def af_lighting(lighting): af = 1.00 return af - @layer.add_wrapped() + @cur.add_wrapped() def af_ase(ase): """ Based on HSM page 10-31. @@ -229,3 +230,14 @@ def af_ase(ase): else: af = 1.00 return af + +with cur.add_layer(): + @cur.add_wrapped() + def af_total( + af_lane_width, af_shoulder_width, af_horizontal_curve, af_se_variance, + af_grade, af_driveway_density, af_rumble_cl, af_passing_lanes, + af_twltl, af_rhr, af_lighting, af_ase + ): + # Compute adjustment factor + af = af_lane_width * af_shoulder_width * af_horizontal_curve * af_se_variance * af_grade * af_driveway_density * af_rumble_cl * af_passing_lanes * af_twltl * af_rhr * af_lighting * af_ase + return af \ No newline at end of file diff --git a/hsm/rtl_seg/model2.py b/hsm/rtl_seg/model2.py deleted file mode 100644 index 9dc9d09..0000000 --- a/hsm/rtl_seg/model2.py +++ /dev/null @@ -1,231 +0,0 @@ -# ============================================================================= -# LOAD DEPENDENCIES -# ============================================================================= - -from modelsandbox import Model -import os, math -import numpy as np -MODEL_PATH = os.path.abspath(os.path.dirname(__file__)) -SCHEMA_PATH = os.path.join(MODEL_PATH, 'schemas') - - -# ============================================================================= -# PREPARE MODEL -# ============================================================================= - -# Initialize the model class -model = Model() - -# Add data validation -model.add_validation(os.path.join(SCHEMA_PATH, 'validation.json')) - - -# ============================================================================= -# DEFINE MODEL -# ============================================================================= - -# ----------------------------------------------------------------------------- -# Layer #1: SPF calculations -# ----------------------------------------------------------------------------- - -with model.add_layer(): - with model.add_sequence(): - # Add SPF parameters - model.add_schema(os.path.join(SCHEMA_PATH, 'spf.json'), hidden=True) - - # Compute number of crashes - @model.add_wrapped() - def n_kabco(aadt, length): - """ - Based on HSM Equation 10-7. - """ - # Compute number of crashes - n = aadt * length * 365 * 10E-6 * math.exp(-0.312) - return n - - # Compute overdispersion - @model.add_wrapped() - def overdispersion(length): - # Compute overdispersion - k = 0.236 / length - return k - - -# ----------------------------------------------------------------------------- -# Layer #2: AF calculations -# ----------------------------------------------------------------------------- - -with model.root.add_layer() as layer: - - @layer.add_wrapped() - def af_lane_width(lane_width, aadt): - # Compute adjustment factor - if lane_width < 10: - af = np.clip(1.05 + 2.81e-4 * (aadt - 400), 1.05, 1.50) - elif lane_width < 11: - af = np.clip(1.02 + 1.75e-4 * (aadt - 400), 1.02, 1.30) - elif lane_width < 12: - af = np.clip(1.01 + 0.25e-4 * (aadt - 400), 1.01, 1.05) - else: - af = 1 - return af - - @layer.add_wrapped() - def af_shoulder_width(shoulder_width, aadt): - # Compute adjustment factor - if shoulder_width < 2: - af = np.clip(1.10 + 2.50e-4 * (aadt - 400), 1.10, 1.50) - elif shoulder_width < 4: - af = np.clip(1.07 + 1.43e-4 * (aadt - 400), 1.07, 1.30) - elif shoulder_width < 6: - af = np.clip(1.02 + 0.8125e-4 * (aadt - 400), 1.02, 1.15) - elif shoulder_width < 8: - af = 1 - else: - af = np.clip(0.98 + 0.6875e-4 * (aadt - 400), 0.98, 0.87) - return af - - model.add_schema(os.path.join(SCHEMA_PATH, 'af_shoulder_type.json')) - - @layer.add_wrapped() - def af_horizontal_curve(curve_length, curve_radius, spiral): - # Check if provided - if (curve_length == 0) or (curve_radius == 0): - return 1.0 - # Process spiral information - spiral_value = {'both': 1.0, 'one': 0.5, 'neither': 0.0}[spiral] - # Clip values - # - Minimum of 100' length if provided - curve_length = max(curve_length, 100 / 5280) - # - Minimum of 100' radius if provided - curve_radius = max(curve_radius, 100) - # Compute adjustment factor - af = ( - (1.55 * curve_length) + \ - (80.2 / curve_radius) - \ - (0.012 * spiral_value) - ) / (1.55 * curve_length) - return af - - @layer.add_wrapped() - def af_se_variance(se_variance): - """ - Based on HSM equation 10-14, 10-15, 10-16. - - NOTE: Future improvement, code AASHTO SE Tables to automatically calculate - variance from input superelevation and other values. - """ - # Compute adjustment factor - if se_variance < 0.01: - af = 1.00 - elif se_variance >= 0.02: - af = 1.06 + (3 * (se_variance - 0.02)) - else: - af = 1.00 + (6 * (se_variance - 0.01)) - return af - - @layer.add_wrapped() - def af_grade(grade): - """ - Based on HSM table 10-11. - """ - # Enforce positive grade - grade = math.fabs(grade) - # "Level Grade" - if grade <= 3.00: - af = 1.00 - # "Moderate Terrain" - elif grade <= 6.00: - af = 1.10 - # "Steep Terrain" - else: - af = 1.16 - return af - - @layer.add_wrapped() - def af_driveway_density(aadt, length, driveway_density): - """ - Based on HSM equation 10-17 - """ - # Enforce minimum driveway number - if driveway_density < 5: - af = 1.00 - else: - af = (0.322 + (driveway_density * (0.05 - 0.005 * math.log(aadt)))) / \ - (0.322 + (5 * (0.05 - 0.005 * math.log(aadt)))) - return af - - @layer.add_wrapped() - def af_rumble_cl(rumble_cl): - """ - Based on HSM page 10-29 - """ - # Compute binary adjustment factor - if rumble_cl == 1: - af = 0.94 - else: - af = 1.00 - return af - - @layer.add_wrapped() - def af_passing_lanes(passing_lanes): - """ - Based on HSM page 10-29 - """ - # Compute adjustment factor based on number of passing lanes present - if passing_lanes == 0: - af = 1.00 - elif passing_lanes == 1: - af = 0.75 - elif passing_lanes == 2: - af = 0.65 - return af - - @layer.add_wrapped() - def af_twltl(twltl, length, driveway_density): - """ - Based on HSM equation 10-18 and 10-19. - """ - if twltl == 0: - af = 1.00 - elif driveway_density < 5: - af = 1.00 - else: - driveway_prop = ((0.0047 * driveway_density) + (0.0024 * driveway_density ** 2)) / \ - (1.199 + (0.0047 * driveway_density) + (0.0024 * driveway_density ** 2)) - af = 1.0 - (0.7 * driveway_prop * 0.5) - return af - - - @layer.add_wrapped() - def af_rhr(rhr): - """ - Based on HSM equation 10-20 and the roadside hazard rating in appendix 13A - """ - # Compute adjustment factor - af = math.exp(-0.6869 + (0.0668 * rhr)) / math.exp(-0.4865) - return af - - @layer.add_wrapped() - def af_lighting(lighting): - """ - Based on HSM equation 10-21 and table 10-12. - """ - # Compute adjustment factor - if lighting == 1: - af = 1.0 - ((1.0 - (0.72 * 0.382) - (0.83 * 0.618)) * 0.370) - else: - af = 1.00 - return af - - @layer.add_wrapped() - def af_ase(ase): - """ - Based on HSM page 10-31. - """ - # Compute adjustment factor - if ase == 1: - af = 0.93 - else: - af = 1.00 - return af diff --git a/modelsandbox/model/containers.py b/modelsandbox/model/containers.py index 7f508ce..0d9549f 100644 --- a/modelsandbox/model/containers.py +++ b/modelsandbox/model/containers.py @@ -4,7 +4,7 @@ from modelsandbox.model.processors import FunctionProcessor, SchemaProcessor, EmptyProcessor -class ContainerAdder(object): +class ComponentAdder(object): """ Class for adding model components to a container. """ @@ -43,7 +43,7 @@ def wrapper(func) -> FunctionProcessor: return wrapper -class Layer(BaseContainer, ContainerAdder): +class Layer(BaseContainer, ComponentAdder): """ Model layer which processes members concurrently. """ @@ -87,7 +87,7 @@ def analyze(self, **params) -> dict: return returns -class Sequence(BaseContainer, ContainerAdder): +class Sequence(BaseContainer, ComponentAdder): """ Model sequence which processes members sequentially. """ @@ -131,5 +131,5 @@ def analyze(self, **params) -> dict: return returns -Layer._valid_member_types = (Sequence, BaseProcessor) -Sequence._valid_member_types = (Layer, BaseProcessor) +Layer._valid_member_types = (Layer, Sequence, BaseProcessor) +Sequence._valid_member_types = (Layer, Sequence, BaseProcessor) diff --git a/modelsandbox/model/model.py b/modelsandbox/model/model.py index 7b7682b..f61037c 100644 --- a/modelsandbox/model/model.py +++ b/modelsandbox/model/model.py @@ -1,4 +1,7 @@ +from __future__ import annotations from modelsandbox.model.containers import Sequence, Layer +from modelsandbox.model.processors import FunctionProcessor, SchemaProcessor, EmptyProcessor +from modelsandbox.model.base import BaseContainer class Model(object): @@ -8,6 +11,7 @@ class Model(object): def __init__(self, **kwargs): self._root = Sequence(**kwargs) + self._cursor = Cursor(self) def __getitem__(self, index): return self._root[index] @@ -24,6 +28,13 @@ def root(self): """ return self._root + @property + def cursor(self): + """ + Return the cursor for the model. + """ + return self._cursor + @property def params(self): """ @@ -78,4 +89,200 @@ def analyze(self, **params): """ Execute the model. """ - return self._root.analyze(**params) \ No newline at end of file + return self._root.analyze(**params) + + +class Cursor(object): + """ + Class for managing and accessing indexed positions within a model. + """ + + def __init__(self, model: Model): + self.model = model + self.reset() + + def __getitem__(self, index): + """ + Return the model component at the specified index. + """ + return self.get_index(index) + + def __enter__(self): + # Get the currently indexed component + component = self.get_current() + # If the component is a container, attempt to enter it + if isinstance(component, BaseContainer): + if len(component) == 0: + raise IndexError( + "Cannot enter an empty container. Add a new sub-container " + "with the 'add_layer' or 'add_sequence' methods.") + self._index.append(len(component) - 1) + else: + raise TypeError("Cannot enter a non-container component.") + return self.get_current() + + def __exit__(self, exc_type, exc_value, traceback): + # Move down one level in the model cursor + self._index.pop() + + @property + def model(self): + """ + Return the model associated with the cursor. + """ + return self._model + + @model.setter + def model(self, obj: Model): + """ + Set the model associated with the cursor. + """ + if not isinstance(obj, Model): + raise TypeError("Model must be an instance of 'Model'.") + self._model = obj + + @property + def root(self): + """ + Return the root of the model. + """ + return self.model.root + + def _validate_index(self, index: tuple[list, tuple]): + """ + Validate the index to ensure it is a valid path within the model. + """ + if not isinstance(index, (list, tuple, int)): + raise TypeError("Index must be a list, tuple, or integer.") + if isinstance(index, int): + return [index] + if not all(isinstance(i, int) for i in index): + raise TypeError("Index must contain only integers.") + return list(index) + + def reset(self): + """ + Reset the cursor to the root of the model. + """ + self._index = [] # Point to the root of the model + + def _get_from_root(self, index: tuple[list, tuple]): + """ + Traverse the model root to get the indexed component. + """ + # Start at the model root + container = self._model.root + return self._get_from_container(container, index) + + @classmethod + def _get_from_container(cls, container: BaseContainer, index: tuple[list, tuple]): + # Iterate over indices to traverse components + component = container + for i, j in enumerate(index): + try: + component = component[j] + except: + raise IndexError(f"Index {j} at position {i} is out of range.") + return component + + def get_current(self): + """ + Return the current model component. + """ + return self._get_from_root(self._index) + + def get_index(self, index: tuple[list, tuple]): + """ + Return the model component at the specified index relative to the + current cursor position. + """ + return self._get_from_root(self._index + self._validate_index(index)) + + def add_layer(self, *args, **kwargs) -> Cursor: + """ + Add a new model layer to the indexed model container. + + Returns + ------- + This method returns itself to allow for chaining and the use of the + with statement syntax. + """ + # Get current component + container = self.get_current() + # Confirm valid container + if not isinstance(container, BaseContainer): + raise TypeError("Cannot add component to a non-container component.") + # Add member + container.add_layer(*args, **kwargs) + return self + + def add_sequence(self, *args, **kwargs) -> Cursor: + """ + Add a new model sequence to the indexed model layer. + + Returns + ------- + This method returns itself to allow for chaining and the use of the + with statement syntax. + """ + # Get current component + container = self.get_current() + # Confirm valid container + if not isinstance(container, BaseContainer): + raise TypeError("Cannot add component to a non-container component.") + # Add member + container.add_sequence(*args, **kwargs) + return self + + def add_function(self, *args, **kwargs) -> Cursor: + """ + Add a new model function to the indexed model layer. + + Returns + ------- + This method returns itself to allow for chaining and the use of the + with statement syntax. + """ + # Get current component + container = self.get_current() + # Confirm valid container + if not isinstance(container, BaseContainer): + raise TypeError("Cannot add component to a non-container component.") + # Add member + container.add_function(*args, **kwargs) + return self + + def add_schema(self, *args, **kwargs) -> Cursor: + """ + Add a new model schema to the indexed model layer. + + Returns + ------- + This method returns itself to allow for chaining and the use of the + with statement syntax. + """ + # Get current component + container = self.get_current() + # Confirm valid container + if not isinstance(container, BaseContainer): + raise TypeError("Cannot add component to a non-container component.") + # Add member + container.add_schema(*args, **kwargs) + return self + + def add_wrapped(self, *args, **kwargs) -> callable: + """ + Add a new model function to the indexed model layer by returning a + decorator to wrap the function. + + Returns + ------- + This method returns a decorator to wrap the function. + """ + # Get current component + container = self.get_current() + # Confirm valid container + if not isinstance(container, BaseContainer): + raise TypeError("Cannot add component to a non-container component.") + # Add member + return container.add_wrapped(*args, **kwargs) \ No newline at end of file