1
1
2
2
from typing import Type , Dict , Optional , List , Tuple , Any , Union
3
- from pydantic import BaseModel , confloat
3
+ from pydantic import BaseModel , confloat , Field
4
+ from label_studio_sdk .label_interface .objects import PredictionValue
5
+ from typing import Union , List
4
6
5
- from label_studio_sdk .objects import PredictionValue
7
+
8
+ # one or multiple predictions per task
9
+ SingleTaskPredictions = Union [List [PredictionValue ], PredictionValue ]
6
10
7
11
8
12
class ModelResponse (BaseModel ):
9
13
"""
10
14
"""
11
15
model_version : Optional [str ] = None
12
- predictions : List [PredictionValue ]
16
+ predictions : List [SingleTaskPredictions ]
13
17
14
18
def has_model_version (self ) -> bool :
15
19
return bool (self .model_version )
@@ -18,21 +22,16 @@ def update_predictions_version(self) -> None:
18
22
"""
19
23
"""
20
24
for prediction in self .predictions :
21
- if not prediction .model_version :
22
- prediction .model_version = self .model_version
25
+ if isinstance (prediction , PredictionValue ):
26
+ prediction = [prediction ]
27
+ for p in prediction :
28
+ if not p .model_version :
29
+ p .model_version = self .model_version
23
30
24
31
def set_version (self , version : str ) -> None :
25
32
"""
26
33
"""
27
34
self .model_version = version
28
35
# Set the version for each prediction
29
36
self .update_predictions_version ()
30
-
31
- def serialize (self ):
32
- """
33
- """
34
- return {
35
- "model_version" : self .model_version ,
36
- "predictions" : [ p .serialize () for p in self .predictions ]
37
- }
38
37
0 commit comments