@@ -29,8 +29,6 @@ inline void prettify_stan_csv_name(std::string& variable) {
29
29
}
30
30
}
31
31
32
- // FIXME: should consolidate with the options from
33
- // the command line in stan::lang
34
32
struct stan_csv_metadata {
35
33
int stan_version_major;
36
34
int stan_version_minor;
@@ -47,6 +45,7 @@ struct stan_csv_metadata {
47
45
bool save_warmup;
48
46
size_t thin;
49
47
bool append_samples;
48
+ std::string method;
50
49
std::string algorithm;
51
50
std::string engine;
52
51
int max_depth;
@@ -64,8 +63,9 @@ struct stan_csv_metadata {
64
63
num_samples(0 ),
65
64
num_warmup(0 ),
66
65
save_warmup(false ),
67
- thin(0 ),
66
+ thin(1 ),
68
67
append_samples(false ),
68
+ method(" " ),
69
69
algorithm(" " ),
70
70
engine(" " ),
71
71
max_depth(10 ) {}
@@ -101,13 +101,12 @@ class stan_csv_reader {
101
101
stan_csv_reader () {}
102
102
~stan_csv_reader () {}
103
103
104
- static bool read_metadata (std::istream& in, stan_csv_metadata& metadata,
105
- std::ostream* out) {
104
+ static void read_metadata (std::istream& in, stan_csv_metadata& metadata) {
106
105
std::stringstream ss;
107
106
std::string line;
108
107
109
108
if (in.peek () != ' #' )
110
- return false ;
109
+ return ;
111
110
while (in.peek () == ' #' ) {
112
111
std::getline (in, line);
113
112
ss << line << ' \n ' ;
@@ -161,9 +160,15 @@ class stan_csv_reader {
161
160
metadata.model = value;
162
161
} else if (name.compare (" num_samples" ) == 0 ) {
163
162
std::stringstream (value) >> metadata.num_samples ;
163
+ } else if (name.compare (" output_samples" ) == 0 ) { // ADVI config name
164
+ std::stringstream (value) >> metadata.num_samples ;
164
165
} else if (name.compare (" num_warmup" ) == 0 ) {
165
166
std::stringstream (value) >> metadata.num_warmup ;
166
167
} else if (name.compare (" save_warmup" ) == 0 ) {
168
+ // cmdstan args can be "true" and "false", was "1", "0"
169
+ if (value.compare (" true" ) == 0 ) {
170
+ value = " 1" ;
171
+ }
167
172
std::stringstream (value) >> metadata.save_warmup ;
168
173
} else if (name.compare (" thin" ) == 0 ) {
169
174
std::stringstream (value) >> metadata.thin ;
@@ -177,6 +182,8 @@ class stan_csv_reader {
177
182
metadata.random_seed = false ;
178
183
} else if (name.compare (" append_samples" ) == 0 ) {
179
184
std::stringstream (value) >> metadata.append_samples ;
185
+ } else if (name.compare (" method" ) == 0 ) {
186
+ metadata.method = value;
180
187
} else if (name.compare (" algorithm" ) == 0 ) {
181
188
metadata.algorithm = value;
182
189
} else if (name.compare (" engine" ) == 0 ) {
@@ -185,14 +192,10 @@ class stan_csv_reader {
185
192
std::stringstream (value) >> metadata.max_depth ;
186
193
}
187
194
}
188
- if (ss.good () == true )
189
- return false ;
190
-
191
- return true ;
192
195
} // read_metadata
193
196
194
197
static bool read_header (std::istream& in, std::vector<std::string>& header,
195
- std::ostream* out, bool prettify_name = true ) {
198
+ bool prettify_name = true ) {
196
199
std::string line;
197
200
198
201
if (!std::isalpha (in.peek ()))
@@ -216,81 +219,71 @@ class stan_csv_reader {
216
219
return true ;
217
220
}
218
221
219
- static bool read_adaptation (std::istream& in, stan_csv_adaptation& adaptation ,
220
- std::ostream* out ) {
222
+ static void read_adaptation (std::istream& in,
223
+ stan_csv_adaptation& adaptation ) {
221
224
std::stringstream ss;
222
225
std::string line;
223
226
int lines = 0 ;
224
-
225
227
if (in.peek () != ' #' || in.good () == false )
226
- return false ;
227
-
228
+ return ;
228
229
while (in.peek () == ' #' ) {
229
230
std::getline (in, line);
230
231
ss << line << std::endl;
231
232
lines++;
232
233
}
233
234
ss.seekg (std::ios_base::beg);
235
+ if (lines < 2 )
236
+ return ;
234
237
235
- if (lines < 4 )
236
- return false ;
237
-
238
- char comment; // Buffer for comment indicator, #
238
+ std::getline (ss, line); // comment adaptation terminated
239
239
240
- // Skip first two lines
241
- std::getline (ss, line);
242
-
243
- // Stepsize
244
- std::getline (ss, line, ' =' );
240
+ // parse stepsize
241
+ std::getline (ss, line, ' =' ); // stepsize
245
242
boost::trim (line);
246
243
ss >> adaptation.step_size ;
244
+ if (lines == 2 ) // ADVI reports stepsize, no metric
245
+ return ;
247
246
248
- // Metric parameters
249
- std::getline (ss, line);
250
- std::getline (ss, line);
251
- std::getline (ss, line);
247
+ std::getline (ss, line); // consume end of stepsize line
248
+ std::getline (ss, line); // comment elements of mass matrix
249
+ std::getline (ss, line); // diagonal metric or row 1 of dense metric
252
250
253
251
int rows = lines - 3 ;
254
252
int cols = std::count (line.begin (), line.end (), ' ,' ) + 1 ;
255
253
adaptation.metric .resize (rows, cols);
254
+ char comment; // Buffer for comment indicator, #
256
255
256
+ // parse metric, row by row, element by element
257
257
for (int row = 0 ; row < rows; row++) {
258
258
std::stringstream line_ss;
259
259
line_ss.str (line);
260
260
line_ss >> comment;
261
-
262
261
for (int col = 0 ; col < cols; col++) {
263
262
std::string token;
264
263
std::getline (line_ss, token, ' ,' );
265
264
boost::trim (token);
266
265
std::stringstream (token) >> adaptation.metric (row, col);
267
266
}
268
- std::getline (ss, line); // Read in next line
267
+ std::getline (ss, line);
269
268
}
270
-
271
- if (ss.good ())
272
- return false ;
273
- else
274
- return true ;
275
269
}
276
270
277
271
static bool read_samples (std::istream& in, Eigen::MatrixXd& samples,
278
- stan_csv_timing& timing, std::ostream* out ) {
272
+ stan_csv_timing& timing) {
279
273
std::stringstream ss;
280
274
std::string line;
281
275
282
276
int rows = 0 ;
283
277
int cols = -1 ;
284
278
285
279
if (in.peek () == ' #' || in.good () == false )
286
- return false ;
280
+ return false ; // need at least one data row
287
281
288
282
while (in.good ()) {
289
283
bool comment_line = (in.peek () == ' #' );
290
284
bool empty_line = (in.peek () == ' \n ' );
291
285
292
286
std::getline (in, line);
293
-
294
287
if (empty_line)
295
288
continue ;
296
289
if (!line.length ())
@@ -316,11 +309,10 @@ class stan_csv_reader {
316
309
if (cols == -1 ) {
317
310
cols = current_cols;
318
311
} else if (cols != current_cols) {
319
- if (out)
320
- *out << " Error: expected " << cols << " columns, but found "
321
- << current_cols << " instead for row " << rows + 1
322
- << std::endl;
323
- return false ;
312
+ std::stringstream msg;
313
+ msg << " Error: expected " << cols << " columns, but found "
314
+ << current_cols << " instead for row " << rows + 1 ;
315
+ throw std::invalid_argument (msg.str ());
324
316
}
325
317
rows++;
326
318
}
@@ -348,36 +340,45 @@ class stan_csv_reader {
348
340
/* *
349
341
* Parses the file.
350
342
*
343
+ * Throws exception if contents can't be parsed into header + data rows.
344
+ *
345
+ * Emits warning message
346
+ *
351
347
* @param[in] in input stream to parse
352
348
* @param[out] out output stream to send messages
353
349
*/
354
350
static stan_csv parse (std::istream& in, std::ostream* out) {
355
351
stan_csv data;
352
+ std::string line;
356
353
357
- if (! read_metadata (in, data.metadata , out)) {
358
- if (out)
359
- *out << " Warning: non-fatal error reading metadata " << std::endl ;
354
+ read_metadata (in, data.metadata );
355
+ if (! read_header (in, data. header )) {
356
+ throw std::invalid_argument ( " Error: no column names found in csv file " ) ;
360
357
}
361
358
362
- if (!read_header (in, data.header , out)) {
363
- if (out)
364
- *out << " Error: error reading header" << std::endl;
365
- throw std::invalid_argument (" Error with header of input file in parse" );
359
+ // skip warmup draws, if any
360
+ if (data.metadata .algorithm != " fixed_param" && data.metadata .num_warmup > 0
361
+ && data.metadata .save_warmup ) {
362
+ while (in.peek () != ' #' ) {
363
+ std::getline (in, line);
364
+ }
366
365
}
367
366
368
- if (!read_adaptation (in, data.adaptation , out)) {
369
- if (out)
370
- *out << " Warning: non-fatal error reading adaptation data" << std::endl;
367
+ if (data.metadata .algorithm != " fixed_param" ) {
368
+ read_adaptation (in, data.adaptation );
371
369
}
372
370
373
371
data.timing .warmup = 0 ;
374
372
data.timing .sampling = 0 ;
375
373
376
- if (!read_samples (in, data.samples , data.timing , out)) {
377
- if (out)
378
- *out << " Warning: non-fatal error reading samples" << std::endl;
374
+ if (data.metadata .method == " variational" ) {
375
+ std::getline (in, line); // discard variational estimate
379
376
}
380
377
378
+ if (!read_samples (in, data.samples , data.timing )) {
379
+ if (out)
380
+ *out << " Unable to parse sample" << std::endl;
381
+ }
381
382
return data;
382
383
}
383
384
};
0 commit comments