Skip to content

Commit 3de8bac

Browse files
committed
fix parseDisjointPoolConfig and add tests
1 parent 222e4b1 commit 3de8bac

File tree

2 files changed

+68
-36
lines changed

2 files changed

+68
-36
lines changed

source/common/umf_pools/disjoint_pool_config_parser.cpp

Lines changed: 21 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -177,46 +177,31 @@ DisjointPoolAllConfigs parseDisjointPoolConfig(const std::string &config,
177177
};
178178

179179
size_t MaxSize = (std::numeric_limits<size_t>::max)();
180+
size_t EnableBuffers = 1;
180181

181182
// Update pool settings if specified in environment.
182-
size_t EnableBuffers = 1;
183-
if (config != "") {
184-
std::string Params = config;
185-
size_t Pos = Params.find(';');
186-
if (Pos != std::string::npos) {
187-
if (Pos > 0) {
188-
GetValue(Params, Pos, EnableBuffers);
189-
}
190-
Params.erase(0, Pos + 1);
191-
size_t Pos = Params.find(';');
192-
if (Pos != std::string::npos) {
193-
if (Pos > 0) {
194-
GetValue(Params, Pos, MaxSize);
195-
}
196-
Params.erase(0, Pos + 1);
197-
do {
198-
size_t Pos = Params.find(';');
199-
if (Pos != std::string::npos) {
200-
if (Pos > 0) {
201-
std::string MemParams = Params.substr(0, Pos);
202-
MemTypeParser(MemParams);
203-
}
204-
Params.erase(0, Pos + 1);
205-
if (Params.size() == 0) {
206-
break;
207-
}
208-
} else {
209-
MemTypeParser(Params);
210-
break;
211-
}
212-
} while (true);
213-
} else {
214-
// set MaxPoolSize for all configs
215-
GetValue(Params, Params.size(), MaxSize);
216-
}
183+
bool EnableBuffersSet = false;
184+
bool MaxSizeSet = false;
185+
size_t Start = 0;
186+
size_t End = config.find(';');
187+
while (true) {
188+
std::string Param = config.substr(Start, End - Start);
189+
if (!EnableBuffersSet && isdigit(Param[0])) {
190+
GetValue(Param, Param.size(), EnableBuffers);
191+
EnableBuffersSet = true;
192+
} else if (!MaxSizeSet && isdigit(Param[0])) {
193+
GetValue(Param, Param.size(), MaxSize);
194+
MaxSizeSet = true;
217195
} else {
218-
GetValue(Params, Params.size(), EnableBuffers);
196+
MemTypeParser(Param);
219197
}
198+
199+
if (End == std::string::npos) {
200+
break;
201+
}
202+
203+
Start = End + 1;
204+
End = config.find(';', Start);
220205
}
221206

222207
AllConfigs.EnableBuffers = EnableBuffers;

test/usm/usmPoolManager.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See LICENSE.TXT
44
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
55

6+
#include "umf_pools/disjoint_pool_config_parser.hpp"
67
#include "ur_pool_manager.hpp"
78

89
#include <uur/fixtures.h>
@@ -18,6 +19,27 @@ auto createMockPoolHandle() {
1819
[](umf_memory_pool_t *) {});
1920
}
2021

22+
bool compareConfig(const usm::umf_disjoint_pool_config_t &left,
23+
usm::umf_disjoint_pool_config_t &right) {
24+
return left.MaxPoolableSize == right.MaxPoolableSize &&
25+
left.Capacity == right.Capacity &&
26+
left.SlabMinSize == right.SlabMinSize;
27+
}
28+
29+
bool compareConfigs(const usm::DisjointPoolAllConfigs &left,
30+
usm::DisjointPoolAllConfigs &right) {
31+
return left.EnableBuffers == right.EnableBuffers &&
32+
compareConfig(left.Configs[usm::DisjointPoolMemType::Host],
33+
right.Configs[usm::DisjointPoolMemType::Host]) &&
34+
compareConfig(left.Configs[usm::DisjointPoolMemType::Device],
35+
right.Configs[usm::DisjointPoolMemType::Device]) &&
36+
compareConfig(left.Configs[usm::DisjointPoolMemType::Shared],
37+
right.Configs[usm::DisjointPoolMemType::Shared]) &&
38+
compareConfig(
39+
left.Configs[usm::DisjointPoolMemType::SharedReadOnly],
40+
right.Configs[usm::DisjointPoolMemType::SharedReadOnly]);
41+
}
42+
2143
TEST_P(urUsmPoolDescriptorTest, poolIsPerContextTypeAndDevice) {
2244
auto &devices = uur::DevicesEnvironment::instance->devices;
2345

@@ -111,4 +133,29 @@ TEST_P(urUsmPoolManagerTest, poolManagerGetNonexistant) {
111133
}
112134
}
113135

136+
TEST_P(urUsmPoolManagerTest, config) {
137+
// Check default config
138+
usm::DisjointPoolAllConfigs def;
139+
usm::DisjointPoolAllConfigs parsed1 =
140+
usm::parseDisjointPoolConfig("1;host:2M,4,64K;device:4M,4,64K;"
141+
"shared:0,0,2M;read_only_shared:4M,4,2M",
142+
0);
143+
ASSERT_EQ(compareConfigs(def, parsed1), true);
144+
145+
// Check partially set config
146+
usm::DisjointPoolAllConfigs parsed2 =
147+
usm::parseDisjointPoolConfig("1;device:4M;shared:0,0,2M", 0);
148+
ASSERT_EQ(compareConfigs(def, parsed2), true);
149+
150+
// Check non-default config
151+
usm::DisjointPoolAllConfigs test(def);
152+
test.Configs[usm::DisjointPoolMemType::Shared].MaxPoolableSize = 128 * 1024;
153+
test.Configs[usm::DisjointPoolMemType::Shared].Capacity = 4;
154+
test.Configs[usm::DisjointPoolMemType::Shared].SlabMinSize = 64 * 1024;
155+
156+
usm::DisjointPoolAllConfigs parsed3 =
157+
usm::parseDisjointPoolConfig("1;shared:128K,4,64K", 0);
158+
ASSERT_EQ(compareConfigs(test, parsed3), true);
159+
}
160+
114161
UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urUsmPoolManagerTest);

0 commit comments

Comments
 (0)