Skip to content

Commit 51dbaa4

Browse files
authored
Merge pull request #3311 from stan-dev/bugfix/3301-stan-csv-reader
Bugfix/3301 stan csv reader
2 parents 6773989 + 24cfc7e commit 51dbaa4

File tree

7 files changed

+4392
-72
lines changed

7 files changed

+4392
-72
lines changed

src/stan/io/stan_csv_reader.hpp

Lines changed: 58 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ inline void prettify_stan_csv_name(std::string& variable) {
2929
}
3030
}
3131

32-
// FIXME: should consolidate with the options from
33-
// the command line in stan::lang
3432
struct stan_csv_metadata {
3533
int stan_version_major;
3634
int stan_version_minor;
@@ -47,6 +45,7 @@ struct stan_csv_metadata {
4745
bool save_warmup;
4846
size_t thin;
4947
bool append_samples;
48+
std::string method;
5049
std::string algorithm;
5150
std::string engine;
5251
int max_depth;
@@ -64,8 +63,9 @@ struct stan_csv_metadata {
6463
num_samples(0),
6564
num_warmup(0),
6665
save_warmup(false),
67-
thin(0),
66+
thin(1),
6867
append_samples(false),
68+
method(""),
6969
algorithm(""),
7070
engine(""),
7171
max_depth(10) {}
@@ -101,13 +101,12 @@ class stan_csv_reader {
101101
stan_csv_reader() {}
102102
~stan_csv_reader() {}
103103

104-
static bool read_metadata(std::istream& in, stan_csv_metadata& metadata,
105-
std::ostream* out) {
104+
static void read_metadata(std::istream& in, stan_csv_metadata& metadata) {
106105
std::stringstream ss;
107106
std::string line;
108107

109108
if (in.peek() != '#')
110-
return false;
109+
return;
111110
while (in.peek() == '#') {
112111
std::getline(in, line);
113112
ss << line << '\n';
@@ -161,9 +160,15 @@ class stan_csv_reader {
161160
metadata.model = value;
162161
} else if (name.compare("num_samples") == 0) {
163162
std::stringstream(value) >> metadata.num_samples;
163+
} else if (name.compare("output_samples") == 0) { // ADVI config name
164+
std::stringstream(value) >> metadata.num_samples;
164165
} else if (name.compare("num_warmup") == 0) {
165166
std::stringstream(value) >> metadata.num_warmup;
166167
} else if (name.compare("save_warmup") == 0) {
168+
// cmdstan args can be "true" and "false", was "1", "0"
169+
if (value.compare("true") == 0) {
170+
value = "1";
171+
}
167172
std::stringstream(value) >> metadata.save_warmup;
168173
} else if (name.compare("thin") == 0) {
169174
std::stringstream(value) >> metadata.thin;
@@ -177,6 +182,8 @@ class stan_csv_reader {
177182
metadata.random_seed = false;
178183
} else if (name.compare("append_samples") == 0) {
179184
std::stringstream(value) >> metadata.append_samples;
185+
} else if (name.compare("method") == 0) {
186+
metadata.method = value;
180187
} else if (name.compare("algorithm") == 0) {
181188
metadata.algorithm = value;
182189
} else if (name.compare("engine") == 0) {
@@ -185,14 +192,10 @@ class stan_csv_reader {
185192
std::stringstream(value) >> metadata.max_depth;
186193
}
187194
}
188-
if (ss.good() == true)
189-
return false;
190-
191-
return true;
192195
} // read_metadata
193196

194197
static bool read_header(std::istream& in, std::vector<std::string>& header,
195-
std::ostream* out, bool prettify_name = true) {
198+
bool prettify_name = true) {
196199
std::string line;
197200

198201
if (!std::isalpha(in.peek()))
@@ -216,81 +219,71 @@ class stan_csv_reader {
216219
return true;
217220
}
218221

219-
static bool read_adaptation(std::istream& in, stan_csv_adaptation& adaptation,
220-
std::ostream* out) {
222+
static void read_adaptation(std::istream& in,
223+
stan_csv_adaptation& adaptation) {
221224
std::stringstream ss;
222225
std::string line;
223226
int lines = 0;
224-
225227
if (in.peek() != '#' || in.good() == false)
226-
return false;
227-
228+
return;
228229
while (in.peek() == '#') {
229230
std::getline(in, line);
230231
ss << line << std::endl;
231232
lines++;
232233
}
233234
ss.seekg(std::ios_base::beg);
235+
if (lines < 2)
236+
return;
234237

235-
if (lines < 4)
236-
return false;
237-
238-
char comment; // Buffer for comment indicator, #
238+
std::getline(ss, line); // comment adaptation terminated
239239

240-
// Skip first two lines
241-
std::getline(ss, line);
242-
243-
// Stepsize
244-
std::getline(ss, line, '=');
240+
// parse stepsize
241+
std::getline(ss, line, '='); // stepsize
245242
boost::trim(line);
246243
ss >> adaptation.step_size;
244+
if (lines == 2) // ADVI reports stepsize, no metric
245+
return;
247246

248-
// Metric parameters
249-
std::getline(ss, line);
250-
std::getline(ss, line);
251-
std::getline(ss, line);
247+
std::getline(ss, line); // consume end of stepsize line
248+
std::getline(ss, line); // comment elements of mass matrix
249+
std::getline(ss, line); // diagonal metric or row 1 of dense metric
252250

253251
int rows = lines - 3;
254252
int cols = std::count(line.begin(), line.end(), ',') + 1;
255253
adaptation.metric.resize(rows, cols);
254+
char comment; // Buffer for comment indicator, #
256255

256+
// parse metric, row by row, element by element
257257
for (int row = 0; row < rows; row++) {
258258
std::stringstream line_ss;
259259
line_ss.str(line);
260260
line_ss >> comment;
261-
262261
for (int col = 0; col < cols; col++) {
263262
std::string token;
264263
std::getline(line_ss, token, ',');
265264
boost::trim(token);
266265
std::stringstream(token) >> adaptation.metric(row, col);
267266
}
268-
std::getline(ss, line); // Read in next line
267+
std::getline(ss, line);
269268
}
270-
271-
if (ss.good())
272-
return false;
273-
else
274-
return true;
275269
}
276270

277271
static bool read_samples(std::istream& in, Eigen::MatrixXd& samples,
278-
stan_csv_timing& timing, std::ostream* out) {
272+
stan_csv_timing& timing) {
279273
std::stringstream ss;
280274
std::string line;
281275

282276
int rows = 0;
283277
int cols = -1;
284278

285279
if (in.peek() == '#' || in.good() == false)
286-
return false;
280+
return false; // need at least one data row
287281

288282
while (in.good()) {
289283
bool comment_line = (in.peek() == '#');
290284
bool empty_line = (in.peek() == '\n');
291285

292286
std::getline(in, line);
293-
294287
if (empty_line)
295288
continue;
296289
if (!line.length())
@@ -316,11 +309,10 @@ class stan_csv_reader {
316309
if (cols == -1) {
317310
cols = current_cols;
318311
} else if (cols != current_cols) {
319-
if (out)
320-
*out << "Error: expected " << cols << " columns, but found "
321-
<< current_cols << " instead for row " << rows + 1
322-
<< std::endl;
323-
return false;
312+
std::stringstream msg;
313+
msg << "Error: expected " << cols << " columns, but found "
314+
<< current_cols << " instead for row " << rows + 1;
315+
throw std::invalid_argument(msg.str());
324316
}
325317
rows++;
326318
}
@@ -348,36 +340,45 @@ class stan_csv_reader {
348340
/**
349341
* Parses the file.
350342
*
343+
* Throws exception if contents can't be parsed into header + data rows.
344+
*
345+
* Emits warning message
346+
*
351347
* @param[in] in input stream to parse
352348
* @param[out] out output stream to send messages
353349
*/
354350
static stan_csv parse(std::istream& in, std::ostream* out) {
355351
stan_csv data;
352+
std::string line;
356353

357-
if (!read_metadata(in, data.metadata, out)) {
358-
if (out)
359-
*out << "Warning: non-fatal error reading metadata" << std::endl;
354+
read_metadata(in, data.metadata);
355+
if (!read_header(in, data.header)) {
356+
throw std::invalid_argument("Error: no column names found in csv file");
360357
}
361358

362-
if (!read_header(in, data.header, out)) {
363-
if (out)
364-
*out << "Error: error reading header" << std::endl;
365-
throw std::invalid_argument("Error with header of input file in parse");
359+
// skip warmup draws, if any
360+
if (data.metadata.algorithm != "fixed_param" && data.metadata.num_warmup > 0
361+
&& data.metadata.save_warmup) {
362+
while (in.peek() != '#') {
363+
std::getline(in, line);
364+
}
366365
}
367366

368-
if (!read_adaptation(in, data.adaptation, out)) {
369-
if (out)
370-
*out << "Warning: non-fatal error reading adaptation data" << std::endl;
367+
if (data.metadata.algorithm != "fixed_param") {
368+
read_adaptation(in, data.adaptation);
371369
}
372370

373371
data.timing.warmup = 0;
374372
data.timing.sampling = 0;
375373

376-
if (!read_samples(in, data.samples, data.timing, out)) {
377-
if (out)
378-
*out << "Warning: non-fatal error reading samples" << std::endl;
374+
if (data.metadata.method == "variational") {
375+
std::getline(in, line); // discard variational estimate
379376
}
380377

378+
if (!read_samples(in, data.samples, data.timing)) {
379+
if (out)
380+
*out << "Unable to parse sample" << std::endl;
381+
}
381382
return data;
382383
}
383384
};

0 commit comments

Comments
 (0)