Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 6e18775

Browse files
committed
fix tests/test_pipeline_benchmark.py
1 parent 7f903a8 commit 6e18775

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

src/deepsparse/benchmark/data_creation.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import random
1818
import string
1919
from os import path
20-
from typing import Dict, List, Tuple
20+
from typing import Dict, List, Tuple, get_args
2121

2222
import numpy
2323

@@ -58,15 +58,11 @@ def get_input_schema_type(pipeline: Pipeline) -> str:
5858
if SchemaType.TEXT_SEQ in input_schema_requirements:
5959
if input_schema_fields.get(SchemaType.TEXT_SEQ).alias == SchemaType.TEXT_PROMPT:
6060
return SchemaType.TEXT_PROMPT
61-
sequence_types = [
62-
f.outer_type_ for f in input_schema_fields[SchemaType.TEXT_SEQ].sub_fields
63-
]
61+
sequence_types = get_args(input_schema_fields[SchemaType.TEXT_SEQ].annotation)
6462
if List[str] in sequence_types:
6563
return SchemaType.TEXT_SEQ
6664
elif SchemaType.TEXT_INPUT in input_schema_requirements:
67-
sequence_types = [
68-
f.outer_type_ for f in input_schema_fields[SchemaType.TEXT_INPUT].sub_fields
69-
]
65+
sequence_types = get_args(input_schema_fields[SchemaType.TEXT_INPUT].annotation)
7066
if List[str] in sequence_types:
7167
return SchemaType.TEXT_INPUT
7268
elif SchemaType.QUESTION in input_schema_requirements:

0 commit comments

Comments
 (0)