Skip to content

Commit c336e5a

Browse files
Added check for registering metric providers after prepare() has been called.
1 parent 46bac02 commit c336e5a

File tree

2 files changed

+110
-8
lines changed

2 files changed

+110
-8
lines changed

stars-core/src/main/kotlin/tools/aqua/stars/core/evaluation/TSCEvaluation.kt

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,22 +80,25 @@ class TSCEvaluation<E : EntityType<E, T, S>, T : TickDataType<E, T, S>, S : Segm
8080
* Registers new [MetricProvider]s to the list of metrics that should be called during evaluation.
8181
*
8282
* @param metricProviders The [MetricProvider]s that should be registered.
83+
* @throws IllegalArgumentException If [prepare] has already been called.
8384
*/
8485
fun registerMetricProviders(vararg metricProviders: MetricProvider<E, T, S>) {
85-
evaluationMetrics.register(
86-
metricProviders.filterIsInstance<EvaluationMetricProvider<E, T, S>>())
87-
postEvaluationMetrics.register(
88-
metricProviders.filterIsInstance<PostEvaluationMetricProvider<E, T, S>>())
86+
registerEvaluationMetricProviders(
87+
metricProviders.filterIsInstance<EvaluationMetricProvider<E, T, S>>().toList())
88+
registerPostEvaluationMetricProviders(
89+
metricProviders.filterIsInstance<PostEvaluationMetricProvider<E, T, S>>().toList())
8990
}
9091

9192
/**
9293
* Registers new [EvaluationMetricProvider]s to the list of metrics that should be called during
9394
* evaluation.
9495
*
9596
* @param metricProviders The [EvaluationMetricProvider]s that should be registered.
97+
* @throws IllegalArgumentException If [prepare] has already been called.
9698
*/
97-
fun registerEvaluationMetricProviders(vararg metricProviders: EvaluationMetricProvider<E, T, S>) {
98-
this.evaluationMetrics.register(metricProviders.toList())
99+
fun registerEvaluationMetricProviders(metricProviders: List<EvaluationMetricProvider<E, T, S>>) {
100+
check(tscProjections.isEmpty()) { "TSCEvaluation.prepare() has already been called." }
101+
this.evaluationMetrics.register(metricProviders)
99102
}
100103

101104
/**
@@ -105,9 +108,9 @@ class TSCEvaluation<E : EntityType<E, T, S>, T : TickDataType<E, T, S>, S : Segm
105108
* @param metricProviders The [PostEvaluationMetricProvider]s that should be registered.
106109
*/
107110
fun registerPostEvaluationMetricProviders(
108-
vararg metricProviders: PostEvaluationMetricProvider<E, T, S>
111+
metricProviders: List<PostEvaluationMetricProvider<E, T, S>>
109112
) {
110-
this.postEvaluationMetrics.register(metricProviders.toList())
113+
this.postEvaluationMetrics.register(metricProviders)
111114
}
112115

113116
/**
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
* Copyright 2023 The STARS Project Authors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package tools.aqua.stars.core.evaluation
19+
20+
import kotlin.test.BeforeTest
21+
import kotlin.test.Test
22+
import kotlin.test.assertFailsWith
23+
import tools.aqua.stars.core.metric.metrics.evaluation.SegmentCountMetric
24+
import tools.aqua.stars.core.tsc.TSC
25+
import tools.aqua.stars.core.tsc.builder.*
26+
import tools.aqua.stars.core.tsc.projection.proj
27+
import tools.aqua.stars.core.tsc.projection.projRec
28+
import tools.aqua.stars.core.types.EntityType
29+
import tools.aqua.stars.core.types.SegmentType
30+
import tools.aqua.stars.core.types.TickDataType
31+
32+
/** Tests for list extension functions. */
33+
class TSCEvaluationTest {
34+
35+
/** Placeholder [EntityType] type. */
36+
class EType(override val id: Int, override val tickData: TType) : EntityType<EType, TType, SType>
37+
38+
/** Placeholder [TickDataType] type. */
39+
class TType(
40+
override val currentTick: Double,
41+
override var entities: List<EType>,
42+
override var segment: SType
43+
) : TickDataType<EType, TType, SType>
44+
45+
/** Placeholder [SegmentType] type. */
46+
class SType(
47+
override val tickData: List<TType>,
48+
override val ticks: Map<Double, TType>,
49+
override val tickIDs: List<Double>,
50+
override val segmentSource: String,
51+
override val firstTickId: Double,
52+
override val primaryEntityId: Int
53+
) : SegmentType<EType, TType, SType>
54+
55+
private lateinit var tscEvaluation: TSCEvaluation<EType, TType, SType>
56+
57+
/** Sets up a [TSC] and [TSCEvaluation]. */
58+
@BeforeTest
59+
fun prepare() {
60+
val tsc =
61+
TSC(
62+
root<EType, TType, SType> {
63+
all("TSCRoot") {
64+
valueFunction = { "TSCRoot" }
65+
projectionIDs = mapOf(projRec("all"), proj("projection"))
66+
exclusive("Weather") {
67+
projectionIDs = mapOf(projRec("projection"))
68+
leaf("Truth") { condition = { _ -> true } }
69+
leaf("Falseness") { condition = { _ -> false } }
70+
}
71+
}
72+
})
73+
tscEvaluation = TSCEvaluation(tsc = tsc, numThreads = 1)
74+
}
75+
76+
/** Tests calling prepare() without registered metric provider. */
77+
@Test
78+
fun testPrepareWithoutMetricProvider() {
79+
assertFailsWith<IllegalStateException> { tscEvaluation.prepare() }
80+
}
81+
82+
/** Tests calling prepare() multiple times. */
83+
@Test
84+
fun testMultiplePrepare() {
85+
tscEvaluation.registerMetricProviders(SegmentCountMetric())
86+
tscEvaluation.prepare()
87+
assertFailsWith<IllegalStateException> { tscEvaluation.prepare() }
88+
}
89+
90+
/** Tests registerMetricProvider() after calling prepare(). */
91+
@Test
92+
fun testRegisterMetricProvidersAfterPrepare() {
93+
tscEvaluation.registerMetricProviders(SegmentCountMetric())
94+
tscEvaluation.prepare()
95+
assertFailsWith<IllegalStateException> {
96+
tscEvaluation.registerMetricProviders(SegmentCountMetric())
97+
}
98+
}
99+
}

0 commit comments

Comments
 (0)