Skip to content

Commit e251621

Browse files
[SYCL][E2E] Fix infinite loop bug in Config/select_device.cpp (#12814)
1 parent db09873 commit e251621

File tree

1 file changed

+47
-75
lines changed

1 file changed

+47
-75
lines changed

sycl/test-e2e/Config/select_device.cpp

Lines changed: 47 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
// REQUIRES: gpu
2-
// Post-commit fails due to a bug in test, will fix in a couple of days.
3-
// UNSUPPORTED: gpu-intel-dg2
42
// RUN: %{build} -o %t.out
53
//
64
// RUN: env ONEAPI_DEVICE_SELECTOR="*:gpu" %{run-unfiltered-devices} %t.out DEVICE_INFO write > %t.txt
@@ -86,92 +84,66 @@ static void addEscapeSymbolToSpecialCharacters(std::string &str) {
8684
}
8785
}
8886

89-
static std::vector<DevDescT> getAllowListDesc(std::string allowList) {
87+
static std::vector<DevDescT> getAllowListDesc(std::string_view allowList) {
9088
if (allowList.empty())
9189
return {};
9290

93-
std::string deviceName("DeviceName:");
94-
std::string driverVersion("DriverVersion:");
95-
std::string platformName("PlatformName:");
96-
std::string platformVersion("PlatformVersion:");
9791
std::vector<DevDescT> decDescs;
9892
decDescs.emplace_back();
9993

100-
size_t pos = 0;
101-
while (pos < allowList.size()) {
102-
if ((allowList.compare(pos, deviceName.size(), deviceName)) == 0) {
103-
if ((pos = allowList.find("{{", pos)) == std::string::npos) {
104-
throw std::runtime_error("Malformed device allowlist");
105-
}
106-
size_t start = pos + 2;
107-
if ((pos = allowList.find("}}", pos)) == std::string::npos) {
108-
throw std::runtime_error("Malformed device allowlist");
109-
}
110-
decDescs.back().devName = allowList.substr(start, pos - start);
111-
pos = pos + 2;
94+
auto try_parse = [&](std::string_view str) -> std::optional<std::string> {
95+
// std::string_view::starts_with is C++20.
96+
if (allowList.compare(0, str.size(), str) != 0)
97+
return {};
11298

113-
if (allowList[pos] == ',') {
114-
pos++;
115-
}
116-
}
99+
allowList.remove_prefix(str.size());
117100

118-
else if ((allowList.compare(pos, driverVersion.size(), driverVersion)) ==
119-
0) {
120-
if ((pos = allowList.find("{{", pos)) == std::string::npos) {
121-
throw std::runtime_error("Malformed device allowlist");
122-
}
123-
size_t start = pos + 2;
124-
if ((pos = allowList.find("}}", pos)) == std::string::npos) {
125-
throw std::runtime_error("Malformed device allowlist");
126-
}
127-
decDescs.back().devDriverVer = allowList.substr(start, pos - start);
128-
pos = pos + 2;
101+
using namespace std::string_literals;
102+
auto pattern_start = allowList.find("{{");
103+
if (pattern_start == std::string::npos)
104+
throw std::runtime_error("Malformed "s + std::string{str} + " allowlist"s);
129105

130-
if (allowList[pos] == ',') {
131-
pos++;
132-
}
133-
}
106+
allowList.remove_prefix(pattern_start + 2);
107+
auto pattern_end = allowList.find("}}");
108+
if (pattern_end == std::string::npos)
109+
throw std::runtime_error("Malformed "s + std::string{str} + " allowlist"s);
134110

135-
else if ((allowList.compare(pos, platformName.size(), platformName)) == 0) {
136-
if ((pos = allowList.find("{{", pos)) == std::string::npos) {
137-
throw std::runtime_error("Malformed platform allowlist");
138-
}
139-
size_t start = pos + 2;
140-
if ((pos = allowList.find("}}", pos)) == std::string::npos) {
141-
throw std::runtime_error("Malformed platform allowlist");
142-
}
143-
decDescs.back().platName = allowList.substr(start, pos - start);
144-
pos = pos + 2;
145-
if (allowList[pos] == ',') {
146-
pos++;
147-
}
148-
}
111+
auto result = allowList.substr(0, pattern_end);
112+
allowList.remove_prefix(pattern_end + 2);
149113

150-
else if ((allowList.compare(pos, platformVersion.size(),
151-
platformVersion)) == 0) {
152-
if ((pos = allowList.find("{{", pos)) == std::string::npos) {
153-
throw std::runtime_error("Malformed platform allowlist");
154-
}
155-
size_t start = pos + 2;
156-
if ((pos = allowList.find("}}", pos)) == std::string::npos) {
157-
throw std::runtime_error("Malformed platform allowlist");
158-
}
159-
decDescs.back().platVer = allowList.substr(start, pos - start);
160-
pos = pos + 2;
161-
}
114+
if (allowList[0] == ',')
115+
allowList.remove_prefix(1);
116+
return {std::string{result}};
117+
};
162118

163-
else if (allowList.find('|', pos) != std::string::npos) {
164-
// FIXME: That is wrong and result in a infinite loop. We start processing
165-
// the string from the start here.
166-
pos = allowList.find('|') + 1;
167-
while (allowList[pos] == ' ') {
168-
pos++;
169-
}
170-
decDescs.emplace_back();
171-
} else {
172-
throw std::runtime_error("Malformed platform allowlist");
119+
while (!allowList.empty()) {
120+
if (auto pattern = try_parse("DeviceName:")) {
121+
decDescs.back().devName = *pattern;
122+
continue;
123+
}
124+
if (auto pattern = try_parse("DriverVersion:")) {
125+
decDescs.back().devDriverVer = *pattern;
126+
continue;
127+
}
128+
if (auto pattern = try_parse("PlatformName:")) {
129+
decDescs.back().platName = *pattern;
130+
continue;
131+
}
132+
if (auto pattern = try_parse("PlatformVersion:")) {
133+
decDescs.back().platVer = *pattern;
134+
continue;
173135
}
174-
} // while (pos <= allowList.size())
136+
137+
auto next = allowList.find('|');
138+
if (next == std::string::npos)
139+
throw std::runtime_error("Malformed allowlist");
140+
allowList.remove_prefix(next + 1);
141+
142+
auto non_space = allowList.find_first_not_of(" ");
143+
allowList.remove_prefix(non_space);
144+
decDescs.emplace_back();
145+
}
146+
175147
return decDescs;
176148
}
177149

0 commit comments

Comments
 (0)