5 SpJacFun_config::SpJacFun_config() : compress(false), index_remap(true) {}
11 vmatrix
matmul(
const vmatrix &x,
const vmatrix &y) {
12 vmatrix z(x.rows(), y.cols());
13 Map<vmatrix> zm(&z(0), z.rows(), z.cols());
14 matmul<false, false, false, false>(x, y, zm);
18 dmatrix
matmul(
const dmatrix &x,
const dmatrix &y) {
return x * y; }
21 #include "checkpoint.hpp" 24 bool ParametersChanged::operator()(
const std::vector<Scalar> &x) {
25 bool change = (x != x_prev);
33 #include "code_generator.hpp" 36 void searchReplace(std::string &str,
const std::string &oldStr,
37 const std::string &newStr) {
38 std::string::size_type pos = 0u;
39 while ((pos = str.find(oldStr, pos)) != std::string::npos) {
40 str.replace(pos, oldStr.length(), newStr);
41 pos += newStr.length();
45 std::string code_config::float_ptr() {
return float_str + (gpu ?
"**" :
"*"); }
47 std::string code_config::void_str() {
48 return (gpu ?
"__device__ void" :
"extern \"C\" void");
51 void code_config::init_code() {
53 *cout << indent <<
"int idx = threadIdx.x;" << std::endl;
57 void code_config::write_header_comment() {
58 if (header_comment.length() > 0) *cout << header_comment << std::endl;
61 code_config::code_config()
65 header_comment(
"// Autogenerated - do not edit by hand !"),
66 float_str(xstringify(TMBAD_SCALAR_TYPE)),
69 void write_common(std::ostringstream &buffer, code_config cfg,
size_t node) {
70 std::ostream &cout = *cfg.cout;
74 std::string indent = cfg.indent;
76 cout << indent <<
"asm(\"// Node: " << node <<
"\");" << endl;
77 bool empty_buffer = (buffer.tellp() == 0);
79 std::string str = buffer.str();
81 std::string pattern =
"]";
82 std::string replace =
"][idx]";
83 searchReplace(str, pattern, replace);
85 searchReplace(str,
";v",
"; v");
86 searchReplace(str,
";d",
"; d");
87 cout << indent << str << endl;
91 void write_forward(
global &glob, code_config cfg) {
95 std::ostream &cout = *cfg.cout;
96 cfg.write_header_comment();
97 cout << cfg.void_str() <<
" forward(" << cfg.float_ptr() <<
" v) {" << endl;
100 for (
size_t i = 0; i < glob.
opstack.size(); i++) {
101 std::ostringstream buffer;
102 Writer::cout = &buffer;
103 glob.
opstack[i]->forward(args);
104 write_common(buffer, cfg, i);
110 void write_reverse(
global &glob, code_config cfg) {
114 std::ostream &cout = *cfg.cout;
115 cfg.write_header_comment();
116 cout << cfg.void_str() <<
" reverse(" << cfg.float_ptr() <<
" v, " 117 << cfg.float_ptr() <<
" d) {" << endl;
120 for (
size_t i = glob.
opstack.size(); i > 0;) {
123 std::ostringstream buffer;
124 Writer::cout = &buffer;
125 glob.
opstack[i]->reverse(args);
126 write_common(buffer, cfg, i);
131 void write_all(
global glob, code_config cfg) {
135 std::ostream &cout = *cfg.cout;
136 cout <<
"#include \"global.hpp\"" << endl;
137 cout <<
"#include \"ad_blas.hpp\"" << endl;
138 write_forward(glob, cfg);
139 write_reverse(glob, cfg);
140 cout <<
"int main() {}" << endl;
145 #include "compile.hpp" 148 void compile(
global &glob, code_config cfg) {
150 cfg.asm_comments =
false;
152 file.open(
"tmp.cpp");
155 *cfg.cout <<
"#include <cmath>" << std::endl;
157 <<
"template<class T>T sign(const T &x) { return (x > 0) - (x < 0); }" 160 write_forward(glob, cfg);
162 write_reverse(glob, cfg);
164 int out = system(
"g++ -O3 -g tmp.cpp -o tmp.so -shared -fPIC");
168 void *handle = dlopen(
"./tmp.so", RTLD_NOW);
169 if (handle != NULL) {
170 Rcout <<
"Loading compiled code!" << std::endl;
172 reinterpret_cast<void (*)(Scalar *)
>(dlsym(handle,
"forward"));
174 dlsym(handle,
"reverse"));
180 #include "compression.hpp" 183 std::ostream &operator<<(std::ostream &os,
const period &x) {
184 os <<
"begin: " << x.
begin;
185 os <<
" size: " << x.
size;
186 os <<
" rep: " << x.
rep;
191 size_t max_period_size) {
192 typedef std::ptrdiff_t ptrdiff_t;
195 size_t offset = glob->subgraph_ptr[p.
begin].first;
198 for (
size_t i = 0; i < p.
size; i++) {
204 matrix_view<Index> x(&(glob->
inputs[offset]), nrow, ncol);
206 std::vector<bool> marks(ncol - 1,
false);
208 for (
size_t i = 0; i < nrow; i++) {
209 std::vector<period> pd =
213 for (
size_t j = 0; j < pd.size(); j++) {
214 if (pd[j].begin > 0) {
215 marks[pd[j].begin - 1] =
true;
217 size_t end = pd[j].begin + pd[j].size * pd[j].rep;
218 if (end < marks.size()) marks[end] =
true;
222 std::vector<period> ans;
225 for (
size_t j = 0; j < marks.size(); j++) {
239 size_t compressed_input::input_size()
const {
return n; }
241 void compressed_input::update_increment_pattern()
const {
242 for (
size_t i = 0; i < (size_t)np; i++)
243 increment_pattern[which_periodic[i]] =
244 period_data[period_offsets[i] + counter % period_sizes[i]];
247 void compressed_input::increment(
Args<> &args)
const {
249 update_increment_pattern();
252 for (
size_t i = 0; i < n; i++) inputs[i] += increment_pattern[i];
256 void compressed_input::decrement(
Args<> &args)
const {
257 args.
ptr.first = input_size();
258 for (
size_t i = 0; i < n; i++) inputs[i] -= increment_pattern[i];
261 update_increment_pattern();
265 void compressed_input::forward_init(
Args<> &args)
const {
267 inputs.resize(input_size());
268 for (
size_t i = 0; i < inputs.size(); i++) inputs[i] = args.
input(i);
269 args.
inputs = inputs.data();
273 void compressed_input::reverse_init(
Args<> &args) {
274 inputs.resize(input_size());
275 for (
size_t i = 0; i < inputs.size(); i++)
276 inputs[i] = args.
input(i) + input_diff[i];
278 args.
inputs = inputs.data();
280 args.
ptr.second += m * nrep;
282 update_increment_pattern();
283 args.
ptr.first = input_size();
286 void compressed_input::dependencies_intervals(
Args<> &args,
287 std::vector<Index> &lower,
288 std::vector<Index> &upper)
const {
292 for (
size_t i = 0; i < nrep; i++) {
293 for (
size_t j = 0; j < inputs.size(); j++) {
294 if (inputs[j] < lower[j]) lower[j] = inputs[j];
295 if (inputs[j] > upper[j]) upper[j] = inputs[j];
301 bool compressed_input::test_period(std::vector<ptrdiff_t> &x,
size_t p) {
302 for (
size_t j = 0; j < x.size(); j++) {
303 if (x[j] != x[j % p])
return false;
308 size_t compressed_input::find_shortest(std::vector<ptrdiff_t> &x) {
309 for (
size_t p = 1; p < max_period_size; p++) {
310 if (test_period(x, p))
return p;
315 compressed_input::compressed_input() {}
317 compressed_input::compressed_input(std::vector<Index> &x,
size_t offset,
318 size_t nrow,
size_t m,
size_t ncol,
319 size_t max_period_size)
320 : n(nrow), m(m), nrep(ncol), counter(0), max_period_size(max_period_size) {
321 matrix_view<Index> xm(&x[offset], nrow, ncol);
323 for (
size_t i = 0; i < nrow; i++) {
324 std::vector<ptrdiff_t> rd = xm.row_diff<ptrdiff_t>(i);
326 size_t p = find_shortest(rd);
328 increment_pattern.push_back(rd[0]);
330 which_periodic.push_back(i);
331 period_sizes.push_back(p);
333 size_t pos = std::search(period_data.begin(), period_data.end(),
334 rd.begin(), rd.begin() + p) -
336 if (pos < period_data.size()) {
337 period_offsets.push_back(pos);
339 period_offsets.push_back(period_data.size());
340 period_data.insert(period_data.end(), rd.begin(), rd.begin() + p);
345 np = which_periodic.size();
347 input_diff.resize(n, 0);
350 for (
size_t i = 0; i < nrep; i++) {
357 size_t max_period_size) {
358 opstack.resize(p.
size);
360 for (
size_t i = 0; i < p.
size; i++) {
362 n += opstack[i]->input_size();
363 m += opstack[i]->output_size();
365 ci = compressed_input(glob->
inputs, ptr.first, n, m, p.
rep, max_period_size);
368 StackOp::StackOp(
const StackOp &x) : opstack(x.opstack), ci(x.ci) {}
371 std::vector<const char *> tmp(opstack.size());
372 for (
size_t i = 0; i < opstack.size(); i++) tmp[i] = opstack[i]->op_name();
373 Rcout << cfg.prefix <<
" opstack = " << tmp <<
"\n";
375 Rcout << cfg.prefix <<
" " 377 <<
" = " << ci.nrep <<
"\n";
379 Rcout << cfg.prefix <<
" " 380 <<
"increment_pattern" 381 <<
" = " << ci.increment_pattern <<
"\n";
383 if (ci.which_periodic.size() > 0) {
384 Rcout << cfg.prefix <<
" " 386 <<
" = " << ci.which_periodic <<
"\n";
388 Rcout << cfg.prefix <<
" " 390 <<
" = " << ci.period_sizes <<
"\n";
392 Rcout << cfg.prefix <<
" " 394 <<
" = " << ci.period_offsets <<
"\n";
396 Rcout << cfg.prefix <<
" " 398 <<
" = " << ci.period_data <<
"\n";
405 Index StackOp::input_size()
const {
return ci.n; }
407 Index StackOp::output_size()
const {
return ci.m * ci.nrep; }
410 size_t n = ci.n, m = ci.m, nrep = ci.nrep;
411 std::vector<Index> inputs(n);
412 for (
size_t i = 0; i < (size_t)n; i++) inputs[i] = args.
input(i);
413 std::vector<Index> outputs(m);
414 for (
size_t i = 0; i < (size_t)m; i++) outputs[i] = args.
output(i);
416 size_t np = ci.which_periodic.size();
417 size_t sp = ci.period_data.size();
418 w <<
"for (int count = 0, ";
420 w <<
"i[" << n <<
"]=" << inputs <<
", " 421 <<
"ip[" << n <<
"]=" << ci.increment_pattern <<
", ";
424 w <<
"wp[" << np <<
"]=" << ci.which_periodic <<
", " 425 <<
"ps[" << np <<
"]=" << ci.period_sizes <<
", " 426 <<
"po[" << np <<
"]=" << ci.period_offsets <<
", " 427 <<
"pd[" << sp <<
"]=" << ci.period_data <<
", ";
429 w <<
"o[" << m <<
"]=" << outputs <<
"; " 430 <<
"count < " << nrep <<
"; count++) {\n";
434 args_cpy.set_indirect();
435 for (
size_t k = 0; k < opstack.size(); k++) {
436 opstack[k]->forward_incr(args_cpy);
442 for (
size_t k = 0; k < np; k++)
443 w <<
"ip[wp[" << k <<
"]] = pd[po[" << k <<
"] + count % ps[" << k
449 for (
size_t k = 0; k < n; k++) w <<
"i[" << k <<
"] += ip[" << k <<
"]; ";
453 for (
size_t k = 0; k < m; k++) w <<
"o[" << k <<
"] += " << m <<
"; ";
461 size_t n = ci.n, m = ci.m, nrep = ci.nrep;
462 std::vector<ptrdiff_t> inputs(input_size());
463 for (
size_t i = 0; i < inputs.size(); i++) {
465 if (-ci.input_diff[i] < ci.input_diff[i]) {
466 tmp = -((ptrdiff_t)-ci.input_diff[i]);
468 tmp = ci.input_diff[i];
470 inputs[i] = args.
input(i) + tmp;
472 std::vector<Index> outputs(ci.m);
473 for (
size_t i = 0; i < (size_t)ci.m; i++)
474 outputs[i] = args.
output(i) + ci.m * ci.nrep;
476 size_t np = ci.which_periodic.size();
477 size_t sp = ci.period_data.size();
478 w <<
"for (int count = " << nrep <<
", ";
480 w <<
"i[" << n <<
"]=" << inputs <<
", " 481 <<
"ip[" << n <<
"]=" << ci.increment_pattern <<
", ";
484 w <<
"wp[" << np <<
"]=" << ci.which_periodic <<
", " 485 <<
"ps[" << np <<
"]=" << ci.period_sizes <<
", " 486 <<
"po[" << np <<
"]=" << ci.period_offsets <<
", " 487 <<
"pd[" << sp <<
"]=" << ci.period_data <<
", ";
489 w <<
"o[" << m <<
"]=" << outputs <<
"; " 490 <<
"count > 0 ; ) {\n";
496 for (
size_t k = 0; k < np; k++)
497 w <<
"ip[wp[" << k <<
"]] = pd[po[" << k <<
"] + count % ps[" << k
503 for (
size_t k = 0; k < n; k++) w <<
"i[" << k <<
"] -= ip[" << k <<
"]; ";
507 for (
size_t k = 0; k < m; k++) w <<
"o[" << k <<
"] -= " << m <<
"; ";
513 args_cpy.set_indirect();
514 args_cpy.
ptr.first = ci.n;
515 args_cpy.
ptr.second = ci.m;
516 for (
size_t k = opstack.size(); k > 0;) {
518 opstack[k]->reverse_decr(args_cpy);
526 void StackOp::dependencies(
Args<> args, Dependencies &dep)
const {
527 std::vector<Index> lower;
528 std::vector<Index> upper;
529 ci.dependencies_intervals(args, lower, upper);
530 for (
size_t i = 0; i < lower.size(); i++) {
531 dep.add_interval(lower[i], upper[i]);
535 const char *StackOp::op_name() {
return "StackOp"; }
545 std::vector<Index> remap = radix::first_occurance<Index>(h);
550 for (
size_t i = 0; i < glob.
opstack.size(); i++) {
552 glob.
opstack[i]->dependencies(args, dep);
554 Index var = args.
ptr.second;
555 toposort_remap<Index> fb(remap, var);
560 std::vector<Index> ord = radix::order<Index>(remap);
561 std::vector<Index> v2o = glob.
var2op();
562 glob.subgraph_seq =
subset(v2o, ord);
568 std::vector<Index> remap(glob.
values.size(), Index(-1));
570 for (
size_t i = 0; i < glob.
opstack.size(); i++) {
572 glob.
opstack[i]->dependencies(args, dep);
574 Index var = args.ptr.second;
575 temporaries_remap<Index> fb(remap, var);
577 glob.
opstack[i]->increment(args.ptr);
580 for (
size_t i = remap.size(); i > 0;) {
582 if (remap[i] == Index(-1))
585 remap[i] = remap[remap[i]];
588 std::vector<Index> ord = radix::order<Index>(remap);
589 std::vector<Index> v2o = glob.
var2op();
590 glob.subgraph_seq =
subset(v2o, ord);
596 std::vector<bool> visited(glob.
opstack.size(),
false);
597 std::vector<Index> v2o = glob.
var2op();
598 std::vector<Index> stack;
599 std::vector<Index> result;
602 for (
size_t k = 0; k < glob.
dep_index.size(); k++) {
604 Index i = v2o[dep_var];
608 while (stack.size() > 0) {
609 Index i = stack.back();
610 args.
ptr = glob.subgraph_ptr[i];
612 glob.
opstack[i]->dependencies(args, dep);
613 dfs_add_to_stack<Index> add_to_stack(stack, visited, v2o);
614 size_t before = stack.size();
615 dep.apply(add_to_stack);
616 size_t after = stack.size();
617 if (before == after) {
624 glob.subgraph_seq = result;
630 void compress(
global &glob,
size_t max_period_size) {
631 size_t min_period_rep = TMBAD_MIN_PERIOD_REP;
634 std::vector<period> periods = p.find_all();
636 std::vector<period> periods_expand;
637 for (
size_t i = 0; i < periods.size(); i++) {
638 std::vector<period> tmp =
split_period(&glob, periods[i], max_period_size);
640 if (tmp.size() > 10) {
642 tmp.push_back(periods[i]);
645 for (
size_t j = 0; j < tmp.size(); j++) {
646 if (tmp[j].rep > 1) periods_expand.push_back(tmp[j]);
650 std::swap(periods, periods_expand);
654 for (
size_t i = 0; i < periods.size(); i++) {
656 TMBAD_ASSERT(p.
rep >= 1);
657 while (k < p.
begin) {
658 glob.
opstack[k]->increment(ptr);
663 get_glob()->getOperator<StackOp>(&glob, p, ptr, max_period_size);
665 for (
size_t j = 0; j < p.size * p.rep; j++) {
666 ninp += glob.opstack[p.begin + j]->input_size();
667 glob.opstack[p.begin + j]->deallocate();
668 glob.opstack[p.begin + j] = null_op;
670 glob.opstack[p.begin] = pOp;
672 glob.opstack[p.begin + 1] =
676 std::vector<bool> marks(glob.
values.size(),
true);
682 #include "global.hpp" 685 global *global_ptr_data[TMBAD_MAX_NUM_THREADS] = {NULL};
686 global **global_ptr = global_ptr_data;
687 std::ostream *Writer::cout = 0;
688 bool global::fuse = 0;
692 Dependencies::Dependencies() {}
694 void Dependencies::clear() {
699 void Dependencies::add_interval(Index a, Index b) {
700 I.push_back(std::pair<Index, Index>(a, b));
703 void Dependencies::add_segment(Index start, Index size) {
704 if (size > 0) add_interval(start, start + size - 1);
707 void Dependencies::monotone_transform_inplace(
const std::vector<Index> &x) {
708 for (
size_t i = 0; i < this->size(); i++) (*
this)[i] = x[(*this)[i]];
709 for (
size_t i = 0; i < I.size(); i++) {
710 I[i].first = x[I[i].first];
711 I[i].second = x[I[i].second];
715 bool Dependencies::any(
const std::vector<bool> &x)
const {
716 for (
size_t i = 0; i < this->size(); i++)
717 if (x[(*
this)[i]])
return true;
718 for (
size_t i = 0; i < I.size(); i++) {
719 for (Index j = I[i].first; j <= I[i].second; j++) {
720 if (x[j])
return true;
726 std::string tostr(
const Index &x) {
727 std::ostringstream strs;
732 std::string tostr(
const Scalar &x) {
733 std::ostringstream strs;
738 Writer::Writer(std::string str) : std::string(str) {}
740 Writer::Writer(Scalar x) : std::string(tostr(x)) {}
744 std::string Writer::p(std::string x) {
return "(" + x +
")"; }
746 Writer Writer::operator+(
const Writer &other) {
747 return p(*
this +
" + " + other);
750 Writer Writer::operator-(
const Writer &other) {
751 return p(*
this +
" - " + other);
754 Writer Writer::operator-() {
return " - " + *
this; }
758 Writer Writer::operator/(
const Writer &other) {
return *
this +
" / " + other; }
761 return *
this +
"*" + tostr(other);
764 Writer Writer::operator+(
const Scalar &other) {
765 return p(*
this +
"+" + tostr(other));
768 void Writer::operator=(
const Writer &other) {
769 *cout << *
this +
" = " + other <<
";";
772 void Writer::operator+=(
const Writer &other) {
773 *cout << *
this +
" += " + other <<
";";
776 void Writer::operator-=(
const Writer &other) {
777 *cout << *
this +
" -= " + other <<
";";
780 void Writer::operator*=(
const Writer &other) {
781 *cout << *
this +
" *= " + other <<
";";
784 void Writer::operator/=(
const Writer &other) {
785 *cout << *
this +
" /= " + other <<
";";
788 Position::Position(Index node, Index first, Index second)
789 : node(node), ptr(first, second) {}
791 Position::Position() : node(0), ptr(0, 0) {}
793 bool Position::operator<(
const Position &other)
const {
794 return this->node < other.node;
799 size_t graph::num_neighbors(Index node) {
return p[node + 1] - p[node]; }
801 Index *graph::neighbors(Index node) {
return &(j[p[node]]); }
803 bool graph::empty() {
return p.size() == 0; }
805 size_t graph::num_nodes() {
return (empty() ? 0 : p.size() - 1); }
807 void graph::print() {
808 for (
size_t node = 0; node < num_nodes(); node++) {
809 Rcout << node <<
": ";
810 for (
size_t i = 0; i < num_neighbors(node); i++) {
811 Rcout <<
" " << neighbors(node)[i];
817 std::vector<Index> graph::rowcounts() {
818 std::vector<Index> ans(num_nodes());
819 for (
size_t i = 0; i < ans.size(); i++) ans[i] = num_neighbors(i);
823 std::vector<Index> graph::colcounts() {
824 std::vector<Index> ans(num_nodes());
825 for (
size_t i = 0; i < j.size(); i++) ans[j[i]]++;
829 void graph::bfs(
const std::vector<Index> &start, std::vector<bool> &visited,
830 std::vector<Index> &result) {
831 for (
size_t i = 0; i < start.size(); i++) {
832 Index node = start[i];
833 for (
size_t j_ = 0; j_ < num_neighbors(node); j_++) {
834 Index k = neighbors(node)[j_];
843 void graph::search(std::vector<Index> &start,
bool sort_input,
845 if (mark.size() == 0) mark.resize(num_nodes(),
false);
847 search(start, mark, sort_input, sort_output);
849 for (
size_t i = 0; i < start.size(); i++) mark[start[i]] =
false;
852 void graph::search(std::vector<Index> &start, std::vector<bool> &visited,
853 bool sort_input,
bool sort_output) {
856 for (
size_t i = 0; i < start.size(); i++) visited[start[i]] =
true;
858 bfs(start, visited, start);
863 std::vector<Index> graph::boundary(
const std::vector<Index> &subgraph) {
864 if (mark.size() == 0) mark.resize(num_nodes(),
false);
866 std::vector<Index> boundary;
868 for (
size_t i = 0; i < subgraph.size(); i++) mark[subgraph[i]] =
true;
870 bfs(subgraph, mark, boundary);
872 for (
size_t i = 0; i < subgraph.size(); i++) mark[subgraph[i]] =
false;
873 for (
size_t i = 0; i < boundary.size(); i++) mark[boundary[i]] =
false;
878 graph::graph(
size_t num_nodes,
const std::vector<IndexPair> &edges) {
879 std::vector<IndexPair>::const_iterator it;
880 std::vector<Index> row_counts(num_nodes, 0);
881 for (it = edges.begin(); it != edges.end(); it++) {
882 row_counts[it->first]++;
885 p.resize(num_nodes + 1);
887 for (
size_t i = 0; i < num_nodes; i++) {
888 p[i + 1] = p[i] + row_counts[i];
891 std::vector<Index> k(p);
892 j.resize(edges.size());
893 for (it = edges.begin(); it != edges.end(); it++) {
894 j[k[it->first]++] = it->second;
898 op_info::op_info() : code(0) {
899 static_assert(
sizeof(IntRep) * 8 >= op_flag_count,
900 "'IntRep' not wide enough!");
903 op_info::op_info(op_flag f) : code(1 << f) {}
905 bool op_info::test(
op_flag f)
const {
return code & 1 << f; }
917 global::operation_stack::operation_stack() {}
920 (*this).copy_from(other);
931 if (
this != &other) {
933 (*this).copy_from(other);
938 global::operation_stack::~operation_stack() { (*this).clear(); }
940 void global::operation_stack::clear() {
941 if (any.test(op_info::dynamic)) {
942 for (
size_t i = 0; i < (*this).size(); i++) (*
this)[i]->deallocate();
947 void global::operation_stack::copy_from(
const operation_stack &other) {
948 if (other.
any.
test(op_info::dynamic)) {
949 for (
size_t i = 0; i < other.size(); i++) Base::push_back(other[i]->copy());
951 Base::operator=(other);
953 this->any = other.
any;
957 : forward_compiled(NULL),
958 reverse_compiled(NULL),
962 void global::clear() {
968 subgraph_ptr.resize(0);
969 subgraph_seq.resize(0);
973 void global::shrink_to_fit(
double tol) {
974 std::vector<Scalar>().swap(derivs);
975 std::vector<IndexPair>().swap(subgraph_ptr);
976 if (values.size() < tol * values.capacity())
977 std::vector<Scalar>(values).swap(values);
978 if (inputs.size() < tol * inputs.capacity())
979 std::vector<Index>(inputs).swap(inputs);
980 if (opstack.size() < tol * opstack.capacity())
981 std::vector<OperatorPure *>(opstack).swap(opstack);
984 void global::clear_deriv(Position start) {
985 derivs.resize(values.size());
986 std::fill(derivs.begin() + start.ptr.second, derivs.end(), 0);
989 Scalar &global::value_inv(Index i) {
return values[inv_index[i]]; }
991 Scalar &global::deriv_inv(Index i) {
return derivs[inv_index[i]]; }
993 Scalar &global::value_dep(Index i) {
return values[dep_index[i]]; }
995 Scalar &global::deriv_dep(Index i) {
return derivs[dep_index[i]]; }
997 Position global::begin() {
return Position(0, 0, 0); }
999 Position global::end() {
1000 return Position(opstack.size(), inputs.size(), values.size());
1003 CONSTEXPR
bool global::no_filter::operator[](
size_t i)
const {
return true; }
1005 void global::forward(Position start) {
1006 if (forward_compiled != NULL) {
1007 forward_compiled(values.data());
1011 args.
ptr = start.ptr;
1012 forward_loop(args, start.node);
1015 void global::reverse(Position start) {
1016 if (reverse_compiled != NULL) {
1017 reverse_compiled(values.data(), derivs.data());
1021 reverse_loop(args, start.node);
1024 void global::forward_sub() {
1026 forward_loop_subgraph(args);
1029 void global::reverse_sub() {
1031 reverse_loop_subgraph(args);
1034 void global::forward(std::vector<bool> &marks) {
1040 void global::reverse(std::vector<bool> &marks) {
1046 void global::forward_sub(std::vector<bool> &marks,
1047 const std::vector<bool> &node_filter) {
1050 if (node_filter.size() == 0)
1051 forward_loop_subgraph(args);
1053 forward_loop(args, 0, node_filter);
1056 void global::reverse_sub(std::vector<bool> &marks,
1057 const std::vector<bool> &node_filter) {
1060 if (node_filter.size() == 0)
1061 reverse_loop_subgraph(args);
1063 reverse_loop(args, 0, node_filter);
1066 void global::forward_dense(std::vector<bool> &marks) {
1069 for (
size_t i = 0; i < opstack.size(); i++) {
1070 opstack[i]->forward_incr_mark_dense(args);
1078 for (
size_t i = 0; i < opstack.size(); i++) {
1079 if (opstack[i]->info().test(op_info::updating)) {
1081 opstack[i]->dependencies(args, dep);
1083 for (
size_t i = 0; i < dep.I.size(); i++) {
1084 Index a = dep.I[i].first;
1085 Index b = dep.I[i].second;
1086 marked_intervals.
insert(a, b);
1089 opstack[i]->increment(args.
ptr);
1091 return marked_intervals;
1098 subgraph_cache_ptr();
1099 for (
size_t j = 0; j < subgraph_seq.size(); j++) {
1100 Index i = subgraph_seq[j];
1101 args.
ptr = subgraph_ptr[i];
1102 if (opstack[i]->info().test(op_info::updating)) {
1104 opstack[i]->dependencies(args, dep);
1106 for (
size_t i = 0; i < dep.I.size(); i++) {
1107 Index a = dep.I[i].first;
1108 Index b = dep.I[i].second;
1109 marked_intervals.
insert(a, b);
1113 return marked_intervals;
1116 Replay &global::replay::value_inv(Index i) {
return values[orig.inv_index[i]]; }
1118 Replay &global::replay::deriv_inv(Index i) {
return derivs[orig.inv_index[i]]; }
1120 Replay &global::replay::value_dep(Index i) {
return values[orig.dep_index[i]]; }
1122 Replay &global::replay::deriv_dep(Index i) {
return derivs[orig.dep_index[i]]; }
1124 global::replay::replay(
const global &orig,
global &target)
1125 : orig(orig), target(target) {
1126 TMBAD_ASSERT(&orig != &target);
1129 void global::replay::start() {
1131 if (&target != parent_glob) target.
ad_start();
1132 values = std::vector<Replay>(orig.
values.begin(), orig.
values.end());
1135 void global::replay::stop() {
1136 if (&target != parent_glob) target.
ad_stop();
1137 TMBAD_ASSERT(parent_glob ==
get_glob());
1143 void operator()(Index a, Index b) {
1144 Index n = b - a + 1;
1148 } F = {derivs.data()};
1152 void global::replay::clear_deriv() {
1153 derivs.resize(values.size());
1154 std::fill(derivs.begin(), derivs.end(), Replay(0));
1158 add_updatable_derivs(I);
1162 void global::replay::forward(
bool inv_tags,
bool dep_tags, Position start,
1163 const std::vector<bool> &node_filter) {
1164 TMBAD_ASSERT(&target ==
get_glob());
1166 for (
size_t i = 0; i < orig.
inv_index.size(); i++)
1167 value_inv(i).Independent();
1170 if (node_filter.size() > 0) {
1171 TMBAD_ASSERT(node_filter.size() == orig.
opstack.size());
1177 for (
size_t i = 0; i < orig.
dep_index.size(); i++) value_dep(i).Dependent();
1181 void global::replay::reverse(
bool dep_tags,
bool inv_tags, Position start,
1182 const std::vector<bool> &node_filter) {
1183 TMBAD_ASSERT(&target ==
get_glob());
1185 for (
size_t i = 0; i < orig.
dep_index.size(); i++)
1186 deriv_dep(i).Independent();
1189 if (node_filter.size() > 0) {
1190 TMBAD_ASSERT(node_filter.size() == orig.
opstack.size());
1196 std::fill(derivs.begin(), derivs.begin() + start.ptr.second, Replay(0));
1198 for (
size_t i = 0; i < orig.
inv_index.size(); i++) deriv_inv(i).Dependent();
1202 void global::replay::forward_sub() {
1207 void global::replay::reverse_sub() {
1212 void global::replay::clear_deriv_sub() {
1217 add_updatable_derivs(I);
1221 void global::forward_replay(
bool inv_tags,
bool dep_tags) {
1223 global::replay replay(*
this, new_glob);
1225 replay.forward(inv_tags, dep_tags);
1230 void global::subgraph_cache_ptr()
const {
1231 if (subgraph_ptr.size() == opstack.size())
return;
1232 TMBAD_ASSERT(subgraph_ptr.size() < opstack.size());
1233 if (subgraph_ptr.size() == 0) subgraph_ptr.push_back(IndexPair(0, 0));
1234 for (
size_t i = subgraph_ptr.size(); i < opstack.size(); i++) {
1235 IndexPair ptr = subgraph_ptr[i - 1];
1236 opstack[i - 1]->increment(ptr);
1237 subgraph_ptr.push_back(ptr);
1241 void global::set_subgraph(
const std::vector<bool> &marks,
bool append) {
1242 std::vector<Index> v2o = var2op();
1243 if (!append) subgraph_seq.resize(0);
1244 Index previous = (Index)-1;
1245 for (
size_t i = 0; i < marks.size(); i++) {
1246 if (marks[i] && (v2o[i] != previous)) {
1247 subgraph_seq.push_back(v2o[i]);
1253 void global::mark_subgraph(std::vector<bool> &marks) {
1254 TMBAD_ASSERT(marks.size() == values.size());
1255 clear_array_subgraph(marks,
true);
1258 void global::unmark_subgraph(std::vector<bool> &marks) {
1259 TMBAD_ASSERT(marks.size() == values.size());
1260 clear_array_subgraph(marks,
false);
1263 void global::subgraph_trivial() {
1264 subgraph_cache_ptr();
1265 subgraph_seq.resize(0);
1266 for (
size_t i = 0; i < opstack.size(); i++) subgraph_seq.push_back(i);
1269 void global::clear_deriv_sub() { clear_array_subgraph(derivs); }
1271 global global::extract_sub(std::vector<Index> &var_remap,
global new_glob) {
1272 subgraph_cache_ptr();
1273 TMBAD_ASSERT(var_remap.size() == 0 || var_remap.size() == values.size());
1274 var_remap.resize(values.size(), 0);
1275 std::vector<bool> independent_variable = inv_marks();
1276 std::vector<bool> dependent_variable = dep_marks();
1278 for (
size_t j = 0; j < subgraph_seq.size(); j++) {
1279 Index i = subgraph_seq[j];
1280 args.
ptr = subgraph_ptr[i];
1282 size_t nout = opstack[i]->output_size();
1283 for (
size_t k = 0; k < nout; k++) {
1284 Index new_index = new_glob.
values.size();
1285 Index old_index = args.
output(k);
1286 var_remap[old_index] = new_index;
1287 new_glob.
values.push_back(args.
y(k));
1288 if (independent_variable[old_index]) {
1289 independent_variable[old_index] =
false;
1291 if (dependent_variable[old_index]) {
1292 dependent_variable[old_index] =
false;
1296 size_t nin = opstack[i]->input_size();
1297 for (
size_t k = 0; k < nin; k++) {
1298 new_glob.
inputs.push_back(var_remap[args.
input(k)]);
1304 independent_variable.flip();
1305 dependent_variable.flip();
1307 for (
size_t i = 0; i < inv_index.size(); i++) {
1308 Index old_var = inv_index[i];
1309 if (independent_variable[old_var])
1310 new_glob.
inv_index.push_back(var_remap[old_var]);
1312 for (
size_t i = 0; i < dep_index.size(); i++) {
1313 Index old_var = dep_index[i];
1314 if (dependent_variable[old_var])
1315 new_glob.
dep_index.push_back(var_remap[old_var]);
1320 void global::extract_sub_inplace(std::vector<bool> marks) {
1321 TMBAD_ASSERT(marks.size() == values.size());
1322 std::vector<Index> var_remap(values.size(), 0);
1323 std::vector<bool> independent_variable = inv_marks();
1324 std::vector<bool> dependent_variable = dep_marks();
1327 size_t s = 0, s_input = 0;
1328 std::vector<bool> opstack_deallocate(opstack.size(),
false);
1330 for (
size_t i = 0; i < opstack.size(); i++) {
1331 op_info info = opstack[i]->info();
1333 size_t nout = opstack[i]->output_size();
1334 bool any_marked_output = info.
test(op_info::elimination_protected);
1335 for (
size_t j = 0; j < nout; j++) {
1336 any_marked_output |= args.
y(j);
1338 if (info.
test(op_info::updating) && nout == 0) {
1340 opstack[i]->dependencies_updating(args, dep);
1341 any_marked_output |= dep.any(args.values);
1344 if (any_marked_output) {
1345 for (
size_t k = 0; k < nout; k++) {
1346 Index new_index = s;
1347 Index old_index = args.
output(k);
1348 var_remap[old_index] = new_index;
1349 values[new_index] = values[old_index];
1350 if (independent_variable[old_index]) {
1351 independent_variable[old_index] =
false;
1353 if (dependent_variable[old_index]) {
1354 dependent_variable[old_index] =
false;
1359 size_t nin = opstack[i]->input_size();
1360 for (
size_t k = 0; k < nin; k++) {
1361 inputs[s_input] = var_remap[args.
input(k)];
1365 opstack[i]->increment(args.
ptr);
1366 if (!any_marked_output) {
1367 opstack_deallocate[i] =
true;
1371 independent_variable.flip();
1372 dependent_variable.flip();
1373 std::vector<Index> new_inv_index;
1374 for (
size_t i = 0; i < inv_index.size(); i++) {
1375 Index old_var = inv_index[i];
1376 if (independent_variable[old_var])
1377 new_inv_index.push_back(var_remap[old_var]);
1379 inv_index = new_inv_index;
1380 std::vector<Index> new_dep_index;
1381 for (
size_t i = 0; i < dep_index.size(); i++) {
1382 Index old_var = dep_index[i];
1383 if (dependent_variable[old_var])
1384 new_dep_index.push_back(var_remap[old_var]);
1386 dep_index = new_dep_index;
1388 inputs.resize(s_input);
1391 for (
size_t i = 0; i < opstack.size(); i++) {
1392 if (opstack_deallocate[i]) {
1393 opstack[i]->deallocate();
1395 opstack[k] = opstack[i];
1401 if (opstack.any.test(op_info::dynamic)) this->forward();
1405 std::vector<Index> var_remap;
1406 return extract_sub(var_remap);
1409 std::vector<Index> global::var2op() {
1410 std::vector<Index> var2op(values.size());
1413 for (
size_t i = 0; i < opstack.size(); i++) {
1414 opstack[i]->increment(args.ptr);
1415 for (; j < (size_t)args.ptr.second; j++) {
1422 std::vector<bool> global::var2op(
const std::vector<bool> &values) {
1423 std::vector<bool> ans(opstack.size(),
false);
1426 for (
size_t i = 0; i < opstack.size(); i++) {
1427 opstack[i]->increment(args.ptr);
1428 for (; j < (size_t)args.ptr.second; j++) {
1429 ans[i] = ans[i] || values[j];
1435 std::vector<Index> global::op2var(
const std::vector<Index> &seq) {
1436 std::vector<bool> seq_mark = mark_space(opstack.size(), seq);
1437 std::vector<Index> ans;
1440 for (
size_t i = 0; i < opstack.size(); i++) {
1441 opstack[i]->increment(args.
ptr);
1442 for (; j < (size_t)args.
ptr.second; j++) {
1443 if (seq_mark[i]) ans.push_back(j);
1449 std::vector<bool> global::op2var(
const std::vector<bool> &seq_mark) {
1450 std::vector<bool> ans(values.size());
1453 for (
size_t i = 0; i < opstack.size(); i++) {
1454 opstack[i]->increment(args.ptr);
1455 for (; j < (size_t)args.ptr.second; j++) {
1456 if (seq_mark[i]) ans[j] =
true;
1462 std::vector<Index> global::op2idx(
const std::vector<Index> &var_subset,
1464 std::vector<Index> v2o = var2op();
1465 std::vector<Index> op2idx(opstack.size(), NA);
1466 for (
size_t i = var_subset.size(); i > 0;) {
1468 op2idx[v2o[var_subset[i]]] = i;
1473 std::vector<bool> global::mark_space(
size_t n,
const std::vector<Index> ind) {
1474 std::vector<bool> mark(n,
false);
1475 for (
size_t i = 0; i < ind.size(); i++) {
1476 mark[ind[i]] =
true;
1481 std::vector<bool> global::inv_marks() {
1482 return mark_space(values.size(), inv_index);
1485 std::vector<bool> global::dep_marks() {
1486 return mark_space(values.size(), dep_index);
1489 std::vector<bool> global::subgraph_marks() {
1490 return mark_space(opstack.size(), subgraph_seq);
1493 global::append_edges::append_edges(
size_t &i,
size_t num_nodes,
1494 const std::vector<bool> &keep_var,
1495 std::vector<Index> &var2op,
1496 std::vector<IndexPair> &edges)
1501 op_marks(num_nodes,
false),
1504 void global::append_edges::operator()(Index dep_j) {
1505 if (keep_var[dep_j]) {
1506 size_t k = var2op[dep_j];
1507 if (i != k && !op_marks[k]) {
1512 edges.push_back(edge);
1518 void global::append_edges::start_iteration() { pos = edges.size(); }
1520 void global::append_edges::end_iteration() {
1521 size_t n = edges.size() - pos;
1522 for (
size_t j = 0; j < n; j++) op_marks[edges[pos + j].first] =
false;
1525 graph global::build_graph(
bool transpose,
const std::vector<bool> &keep_var) {
1526 TMBAD_ASSERT(keep_var.size() == values.size());
1528 std::vector<Index> var2op = this->var2op();
1530 bool any_updating =
false;
1533 std::vector<IndexPair> edges;
1536 append_edges F(i, opstack.size(), keep_var, var2op, edges);
1537 for (; i < opstack.size(); i++) {
1538 any_updating |= opstack[i]->info().test(op_info::updating);
1540 opstack[i]->dependencies(args, dep);
1541 F.start_iteration();
1544 opstack[i]->increment(args.
ptr);
1547 size_t begin = edges.size();
1550 for (; i < opstack.size(); i++) {
1552 opstack[i]->dependencies_updating(args, dep);
1553 F.start_iteration();
1556 opstack[i]->increment(args.
ptr);
1558 for (
size_t j = begin; j < edges.size(); j++)
1559 std::swap(edges[j].first, edges[j].second);
1563 for (
size_t j = 0; j < edges.size(); j++)
1564 std::swap(edges[j].first, edges[j].second);
1567 graph G(opstack.size(), edges);
1569 for (
size_t i = 0; i < inv_index.size(); i++)
1570 G.
inv2op.push_back(var2op[inv_index[i]]);
1571 for (
size_t i = 0; i < dep_index.size(); i++)
1572 G.dep2op.push_back(var2op[dep_index[i]]);
1576 graph global::forward_graph(std::vector<bool> keep_var) {
1577 if (keep_var.size() == 0) {
1578 keep_var.resize(values.size(),
true);
1580 TMBAD_ASSERT(values.size() == keep_var.size());
1581 return build_graph(
false, keep_var);
1584 graph global::reverse_graph(std::vector<bool> keep_var) {
1585 if (keep_var.size() == 0) {
1586 keep_var.resize(values.size(),
true);
1588 TMBAD_ASSERT(values.size() == keep_var.size());
1589 return build_graph(
true, keep_var);
1592 bool global::identical(
const global &other)
const {
1593 if (inv_index != other.
inv_index)
return false;
1595 if (dep_index != other.
dep_index)
return false;
1597 if (opstack.size() != other.
opstack.size())
return false;
1599 for (
size_t i = 0; i < opstack.size(); i++) {
1600 if (opstack[i]->identifier() != other.
opstack[i]->identifier())
1604 if (inputs != other.
inputs)
return false;
1606 if (values.size() != other.
values.size())
return false;
1609 IndexPair ptr(0, 0);
1610 for (
size_t i = 0; i < opstack.size(); i++) {
1611 if (opstack[i] == constant) {
1612 if (values[ptr.second] != other.
values[ptr.second])
return false;
1615 opstack[i]->increment(ptr);
1621 hash_t global::hash()
const {
1624 hash(h, inv_index.size());
1626 for (
size_t i = 0; i < inv_index.size(); i++) hash(h, inv_index[i]);
1629 hash(h, dep_index.size());
1631 for (
size_t i = 0; i < dep_index.size(); i++) hash(h, dep_index[i]);
1634 hash(h, opstack.size());
1636 for (
size_t i = 0; i < opstack.size(); i++) hash(h, opstack[i]);
1639 hash(h, inputs.size());
1641 for (
size_t i = 0; i < inputs.size(); i++) hash(h, inputs[i]);
1644 hash(h, values.size());
1647 IndexPair ptr(0, 0);
1648 for (
size_t i = 0; i < opstack.size(); i++) {
1649 if (opstack[i] == constant) {
1650 hash(h, values[ptr.second]);
1653 opstack[i]->increment(ptr);
1660 std::vector<Index> opstack_id;
1662 std::vector<size_t> tmp(opstack.size());
1663 for (
size_t i = 0; i < tmp.size(); i++)
1664 tmp[i] = (
size_t)opstack[i]->identifier();
1665 opstack_id = radix::first_occurance<Index>(tmp);
1666 hash_t spread = (hash_t(1) << (
sizeof(hash_t) * 4)) - 1;
1667 for (
size_t i = 0; i < opstack_id.size(); i++)
1668 opstack_id[i] = (opstack_id[i] + 1) * spread;
1671 std::vector<hash_t> hash_vec(values.size(), 37);
1677 bool have_inv_seed = (cfg.
inv_seed.size() > 0);
1678 if (have_inv_seed) {
1679 TMBAD_ASSERT(cfg.
inv_seed.size() == inv_index.size());
1681 for (
size_t i = 0; i < inv_index.size(); i++) {
1682 hash_vec[inv_index[i]] += (have_inv_seed ? cfg.
inv_seed[i] + 1 : (i + 1));
1687 IndexPair &ptr = args.
ptr;
1688 for (
size_t i = 0; i < opstack.size(); i++) {
1689 if (opstack[i] == inv) {
1690 opstack[i]->increment(ptr);
1695 opstack[i]->dependencies(args, dep);
1698 for (
size_t j = 0; j < dep.size(); j++) {
1700 h = hash_vec[dep[0]];
1702 hash(h, hash_vec[dep[j]]);
1707 hash(h, opstack[i]->identifier());
1710 hash(h, opstack_id[i]);
1715 hash(h, values[ptr.second]);
1718 hash(h, values[ptr.second] > 0);
1722 size_t noutput = opstack[i]->output_size();
1723 for (
size_t j = 0; j < noutput; j++) {
1727 opstack[i]->increment(ptr);
1729 if (!cfg.
reduce)
return hash_vec;
1730 std::vector<hash_t> ans(dep_index.size());
1731 for (
size_t j = 0; j < dep_index.size(); j++) {
1732 ans[j] = hash_vec[dep_index[j]];
1737 std::vector<hash_t> global::hash_sweep(
bool weak)
const {
1744 return hash_sweep(cfg);
1747 void global::eliminate() {
1748 this->shrink_to_fit();
1750 std::vector<bool> marks;
1751 marks.resize(values.size(),
false);
1753 for (
size_t i = 0; i < inv_index.size(); i++) marks[inv_index[i]] =
true;
1754 for (
size_t i = 0; i < dep_index.size(); i++) marks[dep_index[i]] =
true;
1759 set_subgraph(marks);
1761 *
this = extract_sub();
1763 this->extract_sub_inplace(marks);
1764 this->shrink_to_fit();
1767 global::print_config::print_config() : prefix(
""), mark(
"*"), depth(0) {}
1773 IndexPair ptr(0, 0);
1774 std::vector<bool> sgm = subgraph_marks();
1775 bool have_subgraph = (subgraph_seq.size() > 0);
1779 cfg2.prefix = cfg.prefix +
"##";
1780 Rcout << cfg.prefix;
1781 Rcout << setw(7) <<
"OpName:" << setw(7 + have_subgraph)
1782 <<
"Node:" << setw(13) <<
"Value:" << setw(13) <<
"Deriv:" << setw(13)
1787 for (
size_t i = 0; i < opstack.size(); i++) {
1788 Rcout << cfg.prefix;
1789 Rcout << setw(7) << opstack[i]->op_name();
1790 if (have_subgraph) {
1796 Rcout << setw(7) << i;
1797 int numvar = opstack[i]->output_size();
1798 for (
int j = 0; j < numvar + (numvar == 0); j++) {
1799 if (j > 0) Rcout << cfg.prefix;
1800 Rcout << setw((7 + 7) * (j > 0) + 13);
1807 if (derivs.size() == values.size())
1821 IndexPair ptr_old = ptr;
1822 opstack[i]->increment(ptr);
1823 int ninput = ptr.first - ptr_old.first;
1824 for (
int k = 0; k < ninput; k++) {
1825 if (k == 0) Rcout <<
" ";
1826 Rcout <<
" " << inputs[ptr_old.first + k];
1834 if (cfg.depth > 0) opstack[i]->print(cfg2);
1840 global::DynamicInputOutputOperator::DynamicInputOutputOperator(Index ninput,
1842 : ninput_(ninput), noutput_(noutput) {}
1844 Index global::DynamicInputOutputOperator::input_size()
const {
1845 return this->ninput_;
1848 Index global::DynamicInputOutputOperator::output_size()
const {
1849 return this->noutput_;
1852 const char *global::InvOp::op_name() {
return "InvOp"; }
1854 const char *global::DepOp::op_name() {
return "DepOp"; }
1857 args.
y(0).addToTape();
1860 const char *global::ConstOp::op_name() {
return "ConstOp"; }
1863 if (args.const_literals) {
1864 args.
y(0) = args.y_const(0);
1868 global::DataOp::DataOp(Index n) { Base::noutput = n; }
1870 const char *global::DataOp::op_name() {
return "DataOp"; }
1874 global::ZeroOp::ZeroOp(Index n) { Base::noutput = n; }
1876 const char *global::ZeroOp::op_name() {
return "ZeroOp"; }
1880 void global::ZeroOp::operator()(Replay *x, Index n) {
1883 for (
size_t i = 0; i < n; i++) x[i] = y[i];
1886 global::NullOp::NullOp() {}
1888 const char *global::NullOp::op_name() {
return "NullOp"; }
1890 global::NullOp2::NullOp2(Index ninput, Index noutput)
1891 : global::DynamicInputOutputOperator(ninput, noutput) {}
1893 const char *global::NullOp2::op_name() {
return "NullOp2"; }
1895 global::RefOp::RefOp(
global *glob, Index i) : glob(glob), i(i) {}
1916 args.
dx(0) += args.
dy(0);
1920 const char *global::RefOp::op_name() {
return "RefOp"; }
1929 void global::set_fuse(
bool flag) { fuse = flag; }
1933 while (this->opstack.size() > 0) {
1934 OperatorPure *OpTry = this->Fuse(this->opstack.back(), pOp);
1935 if (OpTry == NULL)
break;
1937 this->opstack.pop_back();
1942 this->opstack.push_back(pOp);
1945 bool global::ad_plain::initialized()
const {
return index != NA; }
1947 bool global::ad_plain::on_some_tape()
const {
return initialized(); }
1949 void global::ad_plain::addToTape()
const { TMBAD_ASSERT(initialized()); }
1951 global *global::ad_plain::glob()
const {
1952 return (on_some_tape() ?
get_glob() : NULL);
1955 void global::ad_plain::override_by(
const ad_plain &x)
const {}
1957 global::ad_plain::ad_plain() : index(NA) {}
1959 global::ad_plain::ad_plain(Scalar x) {
1963 global::ad_plain::ad_plain(
ad_aug x) {
1968 Replay global::ad_plain::CopyOp::eval(Replay x0) {
return x0.copy(); }
1970 const char *global::ad_plain::CopyOp::op_name() {
return "CopyOp"; }
1972 ad_plain global::ad_plain::copy()
const {
1977 Replay global::ad_plain::ValOp::eval(Replay x0) {
return x0.copy0(); }
1979 void global::ad_plain::ValOp::dependencies(
Args<> &args,
1980 Dependencies &dep)
const {}
1982 const char *global::ad_plain::ValOp::op_name() {
return "ValOp"; }
1984 ad_plain global::ad_plain::copy0()
const {
1989 ad_plain global::ad_plain::operator+(
const ad_plain &other)
const {
1995 ad_plain global::ad_plain::operator-(
const ad_plain &other)
const {
2012 ad_plain global::ad_plain::operator/(
const ad_plain &other)
const {
2017 const char *global::ad_plain::NegOp::op_name() {
return "NegOp"; }
2019 ad_plain global::ad_plain::operator-()
const {
2024 ad_plain &global::ad_plain::operator+=(
const ad_plain &other) {
2025 *
this = *
this + other;
2029 ad_plain &global::ad_plain::operator-=(
const ad_plain &other) {
2030 *
this = *
this - other;
2034 ad_plain &global::ad_plain::operator*=(
const ad_plain &other) {
2035 *
this = *
this * other;
2039 ad_plain &global::ad_plain::operator/=(
const ad_plain &other) {
2040 *
this = *
this / other;
2044 void global::ad_plain::Dependent() {
2049 void global::ad_plain::Independent() {
2050 Scalar val = (index == NA ? NAN : this->Value());
2055 Scalar &global::ad_plain::Value() {
return get_glob()->
values[index]; }
2057 Scalar global::ad_plain::Value()
const {
return get_glob()->
values[index]; }
2059 Scalar global::ad_plain::Value(
global *glob)
const {
2060 return glob->
values[index];
2063 Scalar &global::ad_plain::Deriv() {
return get_glob()->
derivs[index]; }
2065 void global::ad_start() {
2066 TMBAD_ASSERT2(!in_use,
"Tape already in use");
2067 TMBAD_ASSERT(parent_glob == NULL);
2068 parent_glob = global_ptr[TMBAD_THREAD_NUM];
2069 global_ptr[TMBAD_THREAD_NUM] =
this;
2073 void global::ad_stop() {
2074 TMBAD_ASSERT2(in_use,
"Tape not in use");
2075 global_ptr[TMBAD_THREAD_NUM] = parent_glob;
2080 void global::Independent(std::vector<ad_plain> &x) {
2081 for (
size_t i = 0; i < x.size(); i++) {
2086 global::ad_segment::ad_segment() : n(0), c(0) {}
2097 : x(x), n(r * c), c(c) {}
2101 if (zero_check && all_zero(x, n))
return;
2102 if (all_constant(x, n)) {
2104 size_t m = glob->
values.size();
2107 for (
size_t i = 0; i < n; i++) glob->
values[m + i] = x[i].Value();
2111 if (!is_contiguous(x, n)) {
2113 this->x = x[0].copy();
2114 for (
size_t i = 1; i < n; i++) x[i].copy();
2116 TMBAD_ASSERT2(after - before == n,
2117 "Each invocation of copy() should construct a new variable");
2120 if (n > 0) this->x = x[0];
2123 bool global::ad_segment::identicalZero() {
return !x.initialized(); }
2125 bool global::ad_segment::all_on_active_tape(Replay *x,
size_t n) {
2127 for (
size_t i = 0; i < n; i++) {
2128 bool ok = x[i].on_some_tape() && (x[i].glob() == cur_glob);
2129 if (!ok)
return false;
2134 bool global::ad_segment::is_contiguous(Replay *x,
size_t n) {
2135 if (!all_on_active_tape(x, n))
return false;
2136 for (
size_t i = 1; i < n; i++) {
2137 if (x[i].index() != x[i - 1].index() + 1)
return false;
2142 bool global::ad_segment::all_zero(Replay *x,
size_t n) {
2143 for (
size_t i = 0; i < n; i++) {
2144 if (!x[i].identicalZero())
return false;
2149 bool global::ad_segment::all_constant(Replay *x,
size_t n) {
2150 for (
size_t i = 0; i < n; i++) {
2151 if (!x[i].constant())
return false;
2156 size_t global::ad_segment::size()
const {
return n; }
2158 size_t global::ad_segment::rows()
const {
return n / c; }
2160 size_t global::ad_segment::cols()
const {
return c; }
2162 ad_plain global::ad_segment::operator[](
size_t i)
const {
2164 ans.index = x.index + i;
2168 ad_plain global::ad_segment::offset()
const {
return x; }
2170 Index global::ad_segment::index()
const {
return x.index; }
2175 return on_some_tape() && (this->glob() ==
get_glob());
2182 Index global::ad_aug::index()
const {
return taped_value.index; }
2185 return (on_some_tape() ? data.glob : NULL);
2190 return taped_value.Value(this->data.glob);
2225 while (cur_glob != NULL) {
2226 if (cur_glob == glob)
return true;
2251 return constant() && data.value == Scalar(0);
2255 return constant() && data.value == Scalar(1);
2263 if (
constant() && other.
constant())
return (data.value == other.data.value);
2265 if (
glob() == other.glob())
2271 if (
bothConstant(other))
return Scalar(this->data.value + other.data.value);
2274 return ad_plain(*
this) + ad_plain(other);
2278 if (
bothConstant(other))
return Scalar(this->data.value - other.data.value);
2281 if (this->
identical(other))
return Scalar(0);
2282 return ad_plain(*
this) - ad_plain(other);
2286 if (this->
constant())
return Scalar(-(this->data.value));
2287 return -ad_plain(*
this);
2291 if (
bothConstant(other))
return Scalar(this->data.value * other.data.value);
2296 if (this->
constant())
return ad_plain(other) * Scalar(this->data.value);
2297 if (other.
constant())
return ad_plain(*
this) * Scalar(other.data.value);
2298 return ad_plain(*
this) * ad_plain(other);
2302 if (
bothConstant(other))
return Scalar(this->data.value / other.data.value);
2305 return ad_plain(*
this) / ad_plain(other);
2309 *
this = *
this + other;
2314 *
this = *
this - other;
2319 *
this = *
this * other;
2324 *
this = *
this / other;
2347 Scalar &global::ad_aug::Deriv() {
return taped_value.Deriv(); }
2349 void global::Independent(std::vector<ad_aug> &x) {
2350 for (
size_t i = 0; i < x.size(); i++) {
2355 std::ostream &operator<<(std::ostream &os,
const global::ad_plain &x) {
2360 std::ostream &operator<<(std::ostream &os,
const global::ad_aug &x) {
2363 os <<
"value=" << x.data.glob->values[x.
taped_value.index] <<
", ";
2365 os <<
"tape=" << x.data.glob;
2367 os <<
"const=" << x.data.value;
2373 ad_plain_index::ad_plain_index(
const Index &i) { this->index = i; }
2375 ad_plain_index::ad_plain_index(
const ad_plain &x) : ad_plain(x) {}
2379 ad_aug_index::ad_aug_index(
const ad_aug &x) :
ad_aug(x) {}
2381 ad_aug_index::ad_aug_index(
const ad_plain &x) :
ad_aug(x) {}
2383 Scalar
Value(Scalar x) {
return x; }
2393 bool operator<(
const double &x,
const ad_adapt &y) {
return x < y.Value(); }
2395 bool operator<=(
const double &x,
const ad_adapt &y) {
return x <= y.Value(); }
2397 bool operator>(
const double &x,
const ad_adapt &y) {
return x > y.Value(); }
2399 bool operator>=(
const double &x,
const ad_adapt &y) {
return x >= y.Value(); }
2401 bool operator==(
const double &x,
const ad_adapt &y) {
return x == y.Value(); }
2403 bool operator!=(
const double &x,
const ad_adapt &y) {
return x != y.Value(); }
2405 Writer floor(
const Writer &x) {
2410 const char *FloorOp::op_name() {
return "FloorOp"; }
2411 ad_plain floor(
const ad_plain &x) {
2416 return Scalar(floor(x.
Value()));
2418 return floor(ad_plain(x));
2421 Writer ceil(
const Writer &x) {
2426 const char *CeilOp::op_name() {
return "CeilOp"; }
2430 return Scalar(ceil(x.
Value()));
2432 return ceil(ad_plain(x));
2435 Writer trunc(
const Writer &x) {
2440 const char *TruncOp::op_name() {
return "TruncOp"; }
2441 ad_plain trunc(
const ad_plain &x) {
2446 return Scalar(trunc(x.
Value()));
2448 return trunc(ad_plain(x));
2451 Writer round(
const Writer &x) {
2456 const char *RoundOp::op_name() {
return "RoundOp"; }
2457 ad_plain round(
const ad_plain &x) {
2462 return Scalar(round(x.
Value()));
2464 return round(ad_plain(x));
2467 double sign(
const double &x) {
return (x >= 0) - (x < 0); }
2469 Writer sign(
const Writer &x) {
2474 const char *SignOp::op_name() {
return "SignOp"; }
2478 return Scalar(sign(x.
Value()));
2480 return sign(ad_plain(x));
2483 double ge0(
const double &x) {
return (x >= 0); }
2485 double lt0(
const double &x) {
return (x < 0); }
2487 Writer ge0(
const Writer &x) {
2492 const char *Ge0Op::op_name() {
return "Ge0Op"; }
2496 return Scalar(ge0(x.
Value()));
2498 return ge0(ad_plain(x));
2501 Writer lt0(
const Writer &x) {
2506 const char *Lt0Op::op_name() {
return "Lt0Op"; }
2510 return Scalar(lt0(x.
Value()));
2512 return lt0(ad_plain(x));
2515 Writer fabs(
const Writer &x) {
2521 typedef Scalar Type;
2522 if (args.
dy(0) != Type(0)) args.
dx(0) += args.
dy(0) * sign(args.
x(0));
2524 const char *AbsOp::op_name() {
return "AbsOp"; }
2528 return Scalar(fabs(x.
Value()));
2530 return fabs(ad_plain(x));
2534 Writer sin(
const Writer &x) {
2540 typedef Scalar Type;
2541 if (args.
dy(0) != Type(0)) args.
dx(0) += args.
dy(0) * cos(args.
x(0));
2543 const char *SinOp::op_name() {
return "SinOp"; }
2547 return Scalar(sin(x.
Value()));
2549 return sin(ad_plain(x));
2553 Writer cos(
const Writer &x) {
2559 typedef Scalar Type;
2560 if (args.
dy(0) != Type(0)) args.
dx(0) += args.
dy(0) * -sin(args.
x(0));
2562 const char *CosOp::op_name() {
return "CosOp"; }
2566 return Scalar(cos(x.
Value()));
2568 return cos(ad_plain(x));
2572 Writer exp(
const Writer &x) {
2578 typedef Scalar Type;
2579 if (args.
dy(0) != Type(0)) args.
dx(0) += args.
dy(0) * args.
y(0);
2581 const char *ExpOp::op_name() {
return "ExpOp"; }
2585 return Scalar(exp(x.
Value()));
2587 return exp(ad_plain(x));
2591 Writer log(
const Writer &x) {
2597 typedef Scalar Type;
2598 if (args.
dy(0) != Type(0)) args.
dx(0) += args.
dy(0) * Type(1.) / args.
x(0);
2600 const char *LogOp::op_name() {
return "LogOp"; }
2604 return Scalar(log(x.
Value()));
2606 return log(ad_plain(x));
2610 Writer sqrt(
const Writer &x) {
2616 typedef Scalar Type;
2617 if (args.
dy(0) != Type(0)) args.
dx(0) += args.
dy(0) * Type(0.5) / args.
y(0);
2619 const char *SqrtOp::op_name() {
return "SqrtOp"; }
2623 return Scalar(sqrt(x.
Value()));
2625 return sqrt(ad_plain(x));
2629 Writer tan(
const Writer &x) {
2635 typedef Scalar Type;
2636 if (args.
dy(0) != Type(0))
2637 args.
dx(0) += args.
dy(0) * Type(1.) / (cos(args.
x(0)) * cos(args.
x(0)));
2639 const char *TanOp::op_name() {
return "TanOp"; }
2643 return Scalar(tan(x.
Value()));
2645 return tan(ad_plain(x));
2649 Writer sinh(
const Writer &x) {
2655 typedef Scalar Type;
2656 if (args.
dy(0) != Type(0)) args.
dx(0) += args.
dy(0) * cosh(args.
x(0));
2658 const char *SinhOp::op_name() {
return "SinhOp"; }
2662 return Scalar(sinh(x.
Value()));
2664 return sinh(ad_plain(x));
2668 Writer cosh(
const Writer &x) {
2674 typedef Scalar Type;
2675 if (args.
dy(0) != Type(0)) args.
dx(0) += args.
dy(0) * sinh(args.
x(0));
2677 const char *CoshOp::op_name() {
return "CoshOp"; }
2681 return Scalar(cosh(x.
Value()));
2683 return cosh(ad_plain(x));
2687 Writer tanh(
const Writer &x) {
2693 typedef Scalar Type;
2694 if (args.
dy(0) != Type(0))
2695 args.
dx(0) += args.
dy(0) * Type(1.) / (cosh(args.
x(0)) * cosh(args.
x(0)));
2697 const char *TanhOp::op_name() {
return "TanhOp"; }
2701 return Scalar(tanh(x.
Value()));
2703 return tanh(ad_plain(x));
2707 Writer expm1(
const Writer &x) {
2713 typedef Scalar Type;
2714 if (args.
dy(0) != Type(0)) args.
dx(0) += args.
dy(0) * args.
y(0) + Type(1.);
2716 const char *Expm1::op_name() {
return "Expm1"; }
2720 return Scalar(expm1(x.
Value()));
2722 return expm1(ad_plain(x));
2726 Writer log1p(
const Writer &x) {
2732 typedef Scalar Type;
2733 if (args.
dy(0) != Type(0))
2734 args.
dx(0) += args.
dy(0) * Type(1.) / (args.
x(0) + Type(1.));
2736 const char *Log1p::op_name() {
return "Log1p"; }
2740 return Scalar(log1p(x.
Value()));
2742 return log1p(ad_plain(x));
2746 Writer asin(
const Writer &x) {
2752 typedef Scalar Type;
2753 if (args.
dy(0) != Type(0))
2755 args.
dy(0) * Type(1.) / sqrt(Type(1.) - args.
x(0) * args.
x(0));
2757 const char *AsinOp::op_name() {
return "AsinOp"; }
2761 return Scalar(asin(x.
Value()));
2763 return asin(ad_plain(x));
2767 Writer acos(
const Writer &x) {
2773 typedef Scalar Type;
2774 if (args.
dy(0) != Type(0))
2776 args.
dy(0) * Type(-1.) / sqrt(Type(1.) - args.
x(0) * args.
x(0));
2778 const char *AcosOp::op_name() {
return "AcosOp"; }
2782 return Scalar(acos(x.
Value()));
2784 return acos(ad_plain(x));
2788 Writer atan(
const Writer &x) {
2794 typedef Scalar Type;
2795 if (args.
dy(0) != Type(0))
2796 args.
dx(0) += args.
dy(0) * Type(1.) / (Type(1.) + args.
x(0) * args.
x(0));
2798 const char *AtanOp::op_name() {
return "AtanOp"; }
2802 return Scalar(atan(x.
Value()));
2804 return atan(ad_plain(x));
2808 Writer asinh(
const Writer &x) {
2814 typedef Scalar Type;
2815 if (args.
dy(0) != Type(0))
2817 args.
dy(0) * Type(1.) / sqrt(args.
x(0) * args.
x(0) + Type(1.));
2819 const char *AsinhOp::op_name() {
return "AsinhOp"; }
2820 ad_plain asinh(
const ad_plain &x) {
2825 return Scalar(asinh(x.
Value()));
2827 return asinh(ad_plain(x));
2831 Writer acosh(
const Writer &x) {
2837 typedef Scalar Type;
2838 if (args.
dy(0) != Type(0))
2840 args.
dy(0) * Type(1.) / sqrt(args.
x(0) * args.
x(0) - Type(1.));
2842 const char *AcoshOp::op_name() {
return "AcoshOp"; }
2843 ad_plain acosh(
const ad_plain &x) {
2848 return Scalar(acosh(x.
Value()));
2850 return acosh(ad_plain(x));
2854 Writer atanh(
const Writer &x) {
2860 typedef Scalar Type;
2861 if (args.
dy(0) != Type(0))
2862 args.
dx(0) += args.
dy(0) * Type(1.) / (Type(1) - args.
x(0) * args.
x(0));
2864 const char *AtanhOp::op_name() {
return "AtanhOp"; }
2865 ad_plain atanh(
const ad_plain &x) {
2870 return Scalar(atanh(x.
Value()));
2872 return atanh(ad_plain(x));
2876 Writer pow(
const Writer &x1,
const Writer &x2) {
2879 x1 +
"," + x2 +
")";
2881 const char *PowOp::op_name() {
return "PowOp"; }
2882 ad_plain pow(
const ad_plain &x1,
const ad_plain &x2) {
2889 return pow(ad_plain(x1), ad_plain(x2));
2895 Writer atan2(
const Writer &x1,
const Writer &x2) {
2898 x1 +
"," + x2 +
")";
2900 const char *Atan2::op_name() {
return "Atan2"; }
2901 ad_plain atan2(
const ad_plain &x1,
const ad_plain &x2) {
2906 return Scalar(atan2(x1.
Value(), x2.
Value()));
2908 return atan2(ad_plain(x1), ad_plain(x2));
2914 Writer max(
const Writer &x1,
const Writer &x2) {
2917 x1 +
"," + x2 +
")";
2919 const char *MaxOp::op_name() {
return "MaxOp"; }
2920 ad_plain max(
const ad_plain &x1,
const ad_plain &x2) {
2927 return max(ad_plain(x1), ad_plain(x2));
2933 Writer min(
const Writer &x1,
const Writer &x2) {
2936 x1 +
"," + x2 +
")";
2938 const char *MinOp::op_name() {
return "MinOp"; }
2939 ad_plain min(
const ad_plain &x1,
const ad_plain &x2) {
2946 return min(ad_plain(x1), ad_plain(x2));
2952 if (args.
x(0) == args.
x(1)) {
2953 args.
y(0) = args.
x(2);
2955 args.
y(0) = args.
x(3);
2959 if (args.
x(0) == args.
x(1)) {
2960 args.
dx(2) += args.
dy(0);
2962 args.
dx(3) += args.
dy(0);
2966 args.
y(0) = CondExpEq(args.
x(0), args.
x(1), args.
x(2), args.
x(3));
2970 args.
dx(2) += CondExpEq(args.
x(0), args.
x(1), args.
dy(0), zero);
2971 args.
dx(3) += CondExpEq(args.
x(0), args.
x(1), zero, args.
dy(0));
2975 w <<
"if (" << args.
x(0) <<
"==" << args.
x(1) <<
") ";
2976 args.
y(0) = args.
x(2);
2978 args.
y(0) = args.
x(3);
2982 w <<
"if (" << args.
x(0) <<
"==" << args.
x(1) <<
") ";
2983 args.
dx(2) += args.
dy(0);
2985 args.
dx(3) += args.
dy(0);
2987 const char *CondExpEqOp::op_name() {
2991 Scalar CondExpEq(
const Scalar &x0,
const Scalar &x1,
const Scalar &x2,
2998 ad_plain CondExpEq(
const ad_plain &x0,
const ad_plain &x1,
const ad_plain &x2,
2999 const ad_plain &x3) {
3001 std::vector<ad_plain> x(4);
3017 return CondExpEq(ad_plain(x0), ad_plain(x1), ad_plain(x2), ad_plain(x3));
3021 if (args.
x(0) != args.
x(1)) {
3022 args.
y(0) = args.
x(2);
3024 args.
y(0) = args.
x(3);
3028 if (args.
x(0) != args.
x(1)) {
3029 args.
dx(2) += args.
dy(0);
3031 args.
dx(3) += args.
dy(0);
3035 args.
y(0) = CondExpNe(args.
x(0), args.
x(1), args.
x(2), args.
x(3));
3039 args.
dx(2) += CondExpNe(args.
x(0), args.
x(1), args.
dy(0), zero);
3040 args.
dx(3) += CondExpNe(args.
x(0), args.
x(1), zero, args.
dy(0));
3044 w <<
"if (" << args.
x(0) <<
"!=" << args.
x(1) <<
") ";
3045 args.
y(0) = args.
x(2);
3047 args.
y(0) = args.
x(3);
3051 w <<
"if (" << args.
x(0) <<
"!=" << args.
x(1) <<
") ";
3052 args.
dx(2) += args.
dy(0);
3054 args.
dx(3) += args.
dy(0);
3056 const char *CondExpNeOp::op_name() {
3060 Scalar CondExpNe(
const Scalar &x0,
const Scalar &x1,
const Scalar &x2,
3067 ad_plain CondExpNe(
const ad_plain &x0,
const ad_plain &x1,
const ad_plain &x2,
3068 const ad_plain &x3) {
3070 std::vector<ad_plain> x(4);
3086 return CondExpNe(ad_plain(x0), ad_plain(x1), ad_plain(x2), ad_plain(x3));
3090 if (args.
x(0) > args.
x(1)) {
3091 args.
y(0) = args.
x(2);
3093 args.
y(0) = args.
x(3);
3097 if (args.
x(0) > args.
x(1)) {
3098 args.
dx(2) += args.
dy(0);
3100 args.
dx(3) += args.
dy(0);
3104 args.
y(0) = CondExpGt(args.
x(0), args.
x(1), args.
x(2), args.
x(3));
3108 args.
dx(2) += CondExpGt(args.
x(0), args.
x(1), args.
dy(0), zero);
3109 args.
dx(3) += CondExpGt(args.
x(0), args.
x(1), zero, args.
dy(0));
3113 w <<
"if (" << args.
x(0) <<
">" << args.
x(1) <<
") ";
3114 args.
y(0) = args.
x(2);
3116 args.
y(0) = args.
x(3);
3120 w <<
"if (" << args.
x(0) <<
">" << args.
x(1) <<
") ";
3121 args.
dx(2) += args.
dy(0);
3123 args.
dx(3) += args.
dy(0);
3125 const char *CondExpGtOp::op_name() {
3129 Scalar CondExpGt(
const Scalar &x0,
const Scalar &x1,
const Scalar &x2,
3136 ad_plain CondExpGt(
const ad_plain &x0,
const ad_plain &x1,
const ad_plain &x2,
3137 const ad_plain &x3) {
3139 std::vector<ad_plain> x(4);
3155 return CondExpGt(ad_plain(x0), ad_plain(x1), ad_plain(x2), ad_plain(x3));
3159 if (args.
x(0) < args.
x(1)) {
3160 args.
y(0) = args.
x(2);
3162 args.
y(0) = args.
x(3);
3166 if (args.
x(0) < args.
x(1)) {
3167 args.
dx(2) += args.
dy(0);
3169 args.
dx(3) += args.
dy(0);
3173 args.
y(0) = CondExpLt(args.
x(0), args.
x(1), args.
x(2), args.
x(3));
3177 args.
dx(2) += CondExpLt(args.
x(0), args.
x(1), args.
dy(0), zero);
3178 args.
dx(3) += CondExpLt(args.
x(0), args.
x(1), zero, args.
dy(0));
3182 w <<
"if (" << args.
x(0) <<
"<" << args.
x(1) <<
") ";
3183 args.
y(0) = args.
x(2);
3185 args.
y(0) = args.
x(3);
3189 w <<
"if (" << args.
x(0) <<
"<" << args.
x(1) <<
") ";
3190 args.
dx(2) += args.
dy(0);
3192 args.
dx(3) += args.
dy(0);
3194 const char *CondExpLtOp::op_name() {
3198 Scalar CondExpLt(
const Scalar &x0,
const Scalar &x1,
const Scalar &x2,
3205 ad_plain CondExpLt(
const ad_plain &x0,
const ad_plain &x1,
const ad_plain &x2,
3206 const ad_plain &x3) {
3208 std::vector<ad_plain> x(4);
3224 return CondExpLt(ad_plain(x0), ad_plain(x1), ad_plain(x2), ad_plain(x3));
3228 if (args.
x(0) >= args.
x(1)) {
3229 args.
y(0) = args.
x(2);
3231 args.
y(0) = args.
x(3);
3235 if (args.
x(0) >= args.
x(1)) {
3236 args.
dx(2) += args.
dy(0);
3238 args.
dx(3) += args.
dy(0);
3242 args.
y(0) = CondExpGe(args.
x(0), args.
x(1), args.
x(2), args.
x(3));
3246 args.
dx(2) += CondExpGe(args.
x(0), args.
x(1), args.
dy(0), zero);
3247 args.
dx(3) += CondExpGe(args.
x(0), args.
x(1), zero, args.
dy(0));
3251 w <<
"if (" << args.
x(0) <<
">=" << args.
x(1) <<
") ";
3252 args.
y(0) = args.
x(2);
3254 args.
y(0) = args.
x(3);
3258 w <<
"if (" << args.
x(0) <<
">=" << args.
x(1) <<
") ";
3259 args.
dx(2) += args.
dy(0);
3261 args.
dx(3) += args.
dy(0);
3263 const char *CondExpGeOp::op_name() {
3267 Scalar CondExpGe(
const Scalar &x0,
const Scalar &x1,
const Scalar &x2,
3274 ad_plain CondExpGe(
const ad_plain &x0,
const ad_plain &x1,
const ad_plain &x2,
3275 const ad_plain &x3) {
3277 std::vector<ad_plain> x(4);
3293 return CondExpGe(ad_plain(x0), ad_plain(x1), ad_plain(x2), ad_plain(x3));
3297 if (args.
x(0) <= args.
x(1)) {
3298 args.
y(0) = args.
x(2);
3300 args.
y(0) = args.
x(3);
3304 if (args.
x(0) <= args.
x(1)) {
3305 args.
dx(2) += args.
dy(0);
3307 args.
dx(3) += args.
dy(0);
3311 args.
y(0) = CondExpLe(args.
x(0), args.
x(1), args.
x(2), args.
x(3));
3315 args.
dx(2) += CondExpLe(args.
x(0), args.
x(1), args.
dy(0), zero);
3316 args.
dx(3) += CondExpLe(args.
x(0), args.
x(1), zero, args.
dy(0));
3320 w <<
"if (" << args.
x(0) <<
"<=" << args.
x(1) <<
") ";
3321 args.
y(0) = args.
x(2);
3323 args.
y(0) = args.
x(3);
3327 w <<
"if (" << args.
x(0) <<
"<=" << args.
x(1) <<
") ";
3328 args.
dx(2) += args.
dy(0);
3330 args.
dx(3) += args.
dy(0);
3332 const char *CondExpLeOp::op_name() {
3336 Scalar CondExpLe(
const Scalar &x0,
const Scalar &x1,
const Scalar &x2,
3343 ad_plain CondExpLe(
const ad_plain &x0,
const ad_plain &x1,
const ad_plain &x2,
3344 const ad_plain &x3) {
3346 std::vector<ad_plain> x(4);
3362 return CondExpLe(ad_plain(x0), ad_plain(x1), ad_plain(x2), ad_plain(x3));
3366 Index SumOp::input_size()
const {
return n; }
3368 Index SumOp::output_size()
const {
return 1; }
3370 SumOp::SumOp(
size_t n) : n(n) {}
3372 const char *SumOp::op_name() {
return "SumOp"; }
3374 Index LogSpaceSumOp::input_size()
const {
return this->n; }
3376 Index LogSpaceSumOp::output_size()
const {
return 1; }
3378 LogSpaceSumOp::LogSpaceSumOp(
size_t n) : n(n) {}
3381 Scalar Max = -INFINITY;
3382 for (
size_t i = 0; i < n; i++) {
3383 if (Max < args.
x(i)) Max = args.
x(i);
3386 for (
size_t i = 0; i < n; i++) {
3387 args.
y(0) += exp(args.
x(i) - Max);
3389 args.
y(0) = Max + log(args.
y(0));
3393 std::vector<ad_plain> x(input_size());
3394 for (Index i = 0; i < input_size(); i++) x[i] = args.
x(i);
3395 args.
y(0) = logspace_sum(x);
3398 const char *LogSpaceSumOp::op_name() {
return "LSSumOp"; }
3400 ad_plain logspace_sum(
const std::vector<ad_plain> &x) {
3405 Index LogSpaceSumStrideOp::number_of_terms()
const {
return stride.size(); }
3407 Index LogSpaceSumStrideOp::input_size()
const {
return number_of_terms(); }
3409 Index LogSpaceSumStrideOp::output_size()
const {
return 1; }
3411 LogSpaceSumStrideOp::LogSpaceSumStrideOp(std::vector<Index> stride,
size_t n)
3412 : stride(stride), n(n) {}
3415 Scalar Max = -INFINITY;
3417 size_t m = stride.size();
3418 std::vector<Scalar *> wrk(m);
3419 Scalar **px = &(wrk[0]);
3420 for (
size_t i = 0; i < m; i++) {
3421 px[i] = args.
x_ptr(i);
3424 for (
size_t i = 0; i < n; i++) {
3425 Scalar s = rowsum(px, i);
3426 if (Max < s) Max = s;
3430 for (
size_t i = 0; i < n; i++) {
3431 Scalar s = rowsum(px, i);
3432 args.
y(0) += exp(s - Max);
3434 args.
y(0) = Max + log(args.
y(0));
3438 std::vector<ad_plain> x(input_size());
3439 for (Index i = 0; i < input_size(); i++) x[i] = args.
x(i);
3440 args.
y(0) = logspace_sum_stride(x, stride, n);
3443 void LogSpaceSumStrideOp::dependencies(
Args<> &args, Dependencies &dep)
const {
3444 for (
size_t j = 0; j < (size_t)number_of_terms(); j++) {
3445 size_t K = n * stride[j];
3446 dep.add_segment(args.
input(j), K);
3450 const char *LogSpaceSumStrideOp::op_name() {
return "LSStride"; }
3453 TMBAD_ASSERT(
false);
3457 TMBAD_ASSERT(
false);
3460 ad_plain logspace_sum_stride(
const std::vector<ad_plain> &x,
3461 const std::vector<Index> &stride,
size_t n) {
3462 TMBAD_ASSERT(x.size() == stride.size());
3468 #include "graph2dot.hpp" 3471 void graph2dot(
global glob,
graph G,
bool show_id, std::ostream &cout) {
3472 cout <<
"digraph graphname {\n";
3473 for (
size_t i = 0; i < glob.
opstack.size(); i++) {
3475 cout << i <<
" [label=\"" << glob.
opstack[i]->op_name() <<
"\"];\n";
3477 cout << i <<
" [label=\"" << glob.
opstack[i]->op_name() <<
" " << i
3480 for (
size_t node = 0; node < G.num_nodes(); node++) {
3481 for (
size_t k = 0; k < G.num_neighbors(node); k++) {
3482 cout << node <<
" -> " << G.neighbors(node)[k] <<
";\n";
3485 for (
size_t i = 0; i < glob.subgraph_seq.size(); i++) {
3486 size_t node = glob.subgraph_seq[i];
3487 cout << node <<
" [style=\"filled\"];\n";
3490 std::vector<Index> v2o = glob.
var2op();
3492 cout <<
"{rank=same;";
3493 for (
size_t i = 0; i < glob.
inv_index.size(); i++) {
3498 cout <<
"{rank=same;";
3499 for (
size_t i = 0; i < glob.
dep_index.size(); i++) {
3507 void graph2dot(
global glob,
bool show_id, std::ostream &cout) {
3509 graph2dot(glob, G, show_id, cout);
3512 void graph2dot(
const char *filename,
global glob,
graph G,
bool show_id) {
3513 std::ofstream myfile;
3514 myfile.open(filename);
3515 graph2dot(glob, G, show_id, myfile);
3519 void graph2dot(
const char *filename,
global glob,
bool show_id) {
3520 std::ofstream myfile;
3521 myfile.open(filename);
3522 graph2dot(glob, show_id, myfile);
3527 #include "graph_transform.hpp" 3530 std::vector<size_t>
which(
const std::vector<bool> &x) {
3531 return which<size_t>(x);
3536 for (
size_t i = 0; i < x.size(); i++) ans *= x[i];
3541 const std::vector<bool> &vars) {
3542 std::vector<bool> boundary(vars);
3543 std::vector<bool> node_filter = glob.
var2op(vars);
3546 for (
size_t i = 0; i < vars.size(); i++) boundary[i] = boundary[i] ^ vars[i];
3553 std::vector<bool> node_subset(opstack.size(),
false);
3554 for (
size_t i = 0; i < opstack.size(); i++) {
3560 std::vector<bool> var_subset = glob.
op2var(node_subset);
3568 node_subset = glob.
var2op(var_subset);
3570 return which<Index>(node_subset);
3574 std::vector<Index> ans;
3576 for (
size_t i = 0; i < opstack.size(); i++) {
3577 if (!strcmp(opstack[i]->op_name(), name)) {
3585 bool inv_tags,
bool dep_tags) {
3587 std::vector<Index> seq2(seq);
3589 OperatorPure *invop = glob.getOperator<global::InvOp>();
3590 for (
size_t i = 0; i < seq2.size(); i++) {
3592 if (inv_tags) TMBAD_ASSERT(op != invop);
3600 std::vector<Index> new_inv = glob.
op2var(seq2);
3601 if (!inv_tags) glob.
inv_index.resize(0);
3602 if (!dep_tags) glob.
dep_index.resize(0);
3610 return substitute(glob, seq, inv_tags, dep_tags);
3618 substitute(glob_tree, boundary,
false,
true);
3623 std::vector<Scalar> x0(n);
3624 for (
size_t i = 0; i < n; i++) x0[i] = glob_tree.
value_inv(i);
3630 std::vector<Scalar> J(n);
3631 for (
size_t i = 0; i < n; i++) J[i] = glob_tree.
deriv_inv(i);
3633 for (
size_t i = 0; i < n; i++) V -= J[i] * x0[i];
3635 std::vector<Index> vars = glob.
op2var(boundary);
3638 std::vector<ad_aug_index> res(vars.begin(), vars.end());
3639 for (
size_t i = 0; i < vars.size(); i++) {
3640 res[i] = res[i] * J[i];
3641 if (i == 0) res[i] += V;
3642 if (!sum_) res[i].Dependent();
3654 TMBAD_ASSERT((sign == 1) || (sign == -1));
3658 for (
size_t i = 0; i < x.size(); i++) y += x[i];
3659 if (sign < 0) y = -y;
3667 opstack_size = glob.
opstack.size();
3670 void old_state::restore() {
3672 while (glob.
opstack.size() > opstack_size) {
3673 Index input_size = glob.
opstack.back()->input_size();
3674 Index output_size = glob.
opstack.back()->output_size();
3677 glob.
opstack.back()->deallocate();
3682 term_info::term_info(
global &glob,
bool do_init) :
glob(glob) {
3683 if (do_init) initialize();
3686 void term_info::initialize(std::vector<Index> inv_remap) {
3687 if (inv_remap.size() == 0) inv_remap.resize(glob.
inv_index.size(), 0);
3688 inv_remap = radix::factor<Index>(inv_remap);
3691 id = radix::factor<Index>(term_ids);
3692 Index max_id = *std::max_element(
id.
begin(),
id.
end());
3693 count.resize(max_id + 1, 0);
3694 for (
size_t i = 0; i <
id.size(); i++) {
3699 gk_config::gk_config()
3700 : debug(
false),
adaptive(
false), nan2zero(
true), ytol(1e-2), dx(1) {}
3704 for (
size_t i = 0; i < bound.size(); i++)
3705 if (mask_[i]) count *= bound[i];
3711 bound.resize(dim, bound_);
3713 mask_.resize(dim, flag);
3717 : pointer(0), bound(bound) {
3718 x.resize(bound.size(), 0);
3719 mask_.resize(bound.size(), flag);
3726 for (
size_t i = 0; i < x.size(); i++) {
3728 if (x[i] < bound[i] - 1) {
3734 pointer -= (bound[i] - 1) * N;
3742 multivariate_index::operator size_t() {
return pointer; }
3753 TMBAD_ASSERT(mask.size() == mask_.size());
3757 size_t clique::clique_size() {
return indices.size(); }
3761 void clique::subset_inplace(
const std::vector<bool> &
mask) {
3762 indices =
subset(indices, mask);
3766 void clique::logsum_init() { logsum.resize(
prod_int(dim)); }
3768 bool clique::empty()
const {
return (indices.size() == 0); }
3770 bool clique::contains(Index i) {
3772 for (
size_t j = 0; j < indices.size(); j++) ans |= (i == indices[j]);
3777 std::vector<ad_plain> &offset, Index &stride) {
3779 for (
size_t k = 0; (k < clique_size()) && (indices[k] < ind); k++) {
3784 size_t nx = mv.
count();
3785 std::vector<bool> mask =
lmatch(super.
indices, this->indices);
3788 std::vector<ad_plain> x(nx);
3789 size_t xa_count = mv.
count();
3791 size_t xi_count = mv.
count();
3793 TMBAD_ASSERT(x.size() == xa_count * xi_count);
3794 for (
size_t i = 0; i < xa_count; i++, ++mv) {
3796 for (
size_t j = 0; j < xi_count; j++, ++mv) {
3797 TMBAD_ASSERT(logsum[j].on_some_tape());
3808 xa_count = mv.
count();
3809 offset.resize(xa_count);
3810 for (
size_t i = 0; i < xa_count; i++, ++mv) {
3815 sr_grid::sr_grid() {}
3817 sr_grid::sr_grid(Scalar a, Scalar b,
size_t n) : x(n), w(n) {
3818 Scalar h = (b - a) / n;
3819 for (
size_t i = 0; i < n; i++) {
3820 x[i] = a + h / 2 + i * h;
3825 sr_grid::sr_grid(
size_t n) {
3826 for (
size_t i = 0; i < n; i++) {
3828 w[i] = 1. / (double)n;
3832 size_t sr_grid::size() {
return x.size(); }
3834 ad_plain sr_grid::logw_offset() {
3835 if (logw.size() != w.size()) {
3836 logw.resize(w.size());
3837 for (
size_t i = 0; i < w.size(); i++) logw[i] = log(w[i]);
3844 std::vector<Index> random,
3845 std::vector<sr_grid> grid,
3846 std::vector<Index> random2grid,
3851 replay(glob, new_glob),
3852 tinfo(glob, false) {
3853 inv2grid.resize(glob.
inv_index.size(), 0);
3854 for (
size_t i = 0; i < random2grid.size(); i++) {
3855 inv2grid[random[i]] = random2grid[i];
3858 mark.resize(glob.
values.size(),
false);
3859 for (
size_t i = 0; i < random.size(); i++)
3868 var_remap.resize(glob.
values.size());
3875 terms_done.resize(glob.
dep_index.size(),
false);
3877 std::vector<Index> inv_remap(glob.
inv_index.size());
3878 for (
size_t i = 0; i < inv_remap.size(); i++) inv_remap[i] = -(i + 1);
3879 for (
size_t i = 0; i < random.size(); i++)
3880 inv_remap[random[i]] = inv2grid[random[i]];
3881 inv_remap = radix::factor<Index>(inv_remap);
3882 tinfo.initialize(inv_remap);
3886 std::vector<IndexPair> edges;
3887 std::vector<Index> &inv2op = forward_graph.
inv2op;
3889 for (
size_t i = 0; i < random.size(); i++) {
3890 std::vector<Index> subgraph(1, inv2op[random[i]]);
3891 forward_graph.
search(subgraph);
3892 reverse_graph.
search(subgraph);
3893 for (
size_t l = 0; l < subgraph.size(); l++) {
3894 Index inv_other = op2inv_idx[subgraph[l]];
3895 if (inv_other != NA) {
3896 IndexPair edge(random[i], inv_other);
3897 edges.push_back(edge);
3902 size_t num_nodes = glob.
inv_index.size();
3903 graph G(num_nodes, edges);
3905 std::vector<bool> visited(num_nodes,
false);
3906 std::vector<Index> subgraph;
3907 for (
size_t i = 0; i < random.size(); i++) {
3908 if (visited[random[i]])
continue;
3909 std::vector<Index> sg(1, random[i]);
3910 G.
search(sg, visited,
false,
false);
3911 subgraph.insert(subgraph.end(), sg.begin(), sg.end());
3913 std::reverse(subgraph.begin(), subgraph.end());
3914 TMBAD_ASSERT(random.size() == subgraph.size());
3918 std::vector<size_t> sequential_reduction::get_grid_bounds(
3919 std::vector<Index> inv_index) {
3920 std::vector<size_t> ans(inv_index.size());
3921 for (
size_t i = 0; i < inv_index.size(); i++) {
3922 ans[i] = grid[inv2grid[inv_index[i]]].size();
3927 std::vector<sr_grid *> sequential_reduction::get_grid(
3928 std::vector<Index> inv_index) {
3929 std::vector<sr_grid *> ans(inv_index.size());
3930 for (
size_t i = 0; i < inv_index.size(); i++) {
3931 ans[i] = &(grid[inv2grid[inv_index[i]]]);
3938 size_t id = tinfo.id[dep_index];
3939 size_t count = tinfo.count[id];
3940 bool do_cache = (count >= 2);
3942 if (cache[
id].size() > 0) {
3947 std::vector<sr_grid *> inv_grid = get_grid(inv_index);
3948 std::vector<size_t> grid_bounds = get_grid_bounds(inv_index);
3950 std::vector<ad_aug> ans(mv.
count());
3951 for (
size_t i = 0; i < ans.size(); i++, ++mv) {
3952 for (
size_t j = 0; j < inv_index.size(); j++) {
3953 replay.value_inv(inv_index[j]) = inv_grid[j]->x[mv.
index(j)];
3955 replay.forward_sub();
3956 ans[i] = replay.value_dep(dep_index);
3967 std::vector<Index> super;
3969 for (std::list<clique>::iterator it = cliques.begin(); it != cliques.end();
3971 if ((*it).contains(i)) {
3972 super.insert(super.end(), (*it).indices.begin(), (*it).indices.end());
3978 std::vector<std::vector<ad_plain> > offset_by_clique(c);
3979 std::vector<Index> stride_by_clique(c);
3982 C.
dim = get_grid_bounds(super);
3983 std::list<clique>::iterator it = cliques.begin();
3985 while (it != cliques.end()) {
3986 if ((*it).contains(i)) {
3987 (*it).get_stride(C, i, offset_by_clique[c], stride_by_clique[c]);
3988 it = cliques.erase(it);
3995 std::vector<bool> mask =
lmatch(super, std::vector<Index>(1, i));
3997 C.subset_inplace(mask);
4000 grid[inv2grid[i]].logw_offset();
4002 for (
size_t j = 0; j < C.
logsum.size(); j++) {
4003 std::vector<ad_plain> x;
4004 std::vector<Index> stride;
4005 for (
size_t k = 0; k < offset_by_clique.size(); k++) {
4006 x.push_back(offset_by_clique[k][j]);
4007 stride.push_back(stride_by_clique[k]);
4010 x.push_back(grid[inv2grid[i]].logw_offset());
4011 stride.push_back(1);
4012 C.
logsum[j] = logspace_sum_stride(x, stride, grid[inv2grid[i]].size());
4015 TMBAD_ASSERT(v_end - v_begin == C.
logsum.size());
4017 cliques.push_back(C);
4021 const std::vector<Index> &inv2op = forward_graph.
inv2op;
4023 Index start_node = inv2op[i];
4024 std::vector<Index> subgraph(1, start_node);
4025 forward_graph.
search(subgraph);
4027 std::vector<Index> dep_clique;
4028 std::vector<Index> subgraph_terms;
4029 for (
size_t k = 0; k < subgraph.size(); k++) {
4030 Index node = subgraph[k];
4031 Index dep_idx = op2dep_idx[node];
4032 if (dep_idx != NA && !terms_done[dep_idx]) {
4033 terms_done[dep_idx] =
true;
4034 subgraph_terms.push_back(node);
4035 dep_clique.push_back(dep_idx);
4038 for (
size_t k = 0; k < subgraph_terms.size(); k++) {
4040 subgraph.push_back(subgraph_terms[k]);
4042 reverse_graph.
search(subgraph);
4044 std::vector<Index> inv_clique;
4045 for (
size_t l = 0; l < subgraph.size(); l++) {
4046 Index tmp = op2inv_idx[subgraph[l]];
4047 if (tmp != NA) inv_clique.push_back(tmp);
4050 glob.subgraph_seq = subgraph;
4054 C.
dim = get_grid_bounds(inv_clique);
4057 cliques.push_back(C);
4063 void sequential_reduction::show_cliques() {
4064 Rcout <<
"Cliques: ";
4065 std::list<clique>::iterator it;
4066 for (it = cliques.begin(); it != cliques.end(); ++it) {
4067 Rcout << it->indices <<
" ";
4072 void sequential_reduction::update_all() {
4073 for (
size_t i = 0; i < random.size(); i++)
update(random[i]);
4076 ad_aug sequential_reduction::get_result() {
4078 std::list<clique>::iterator it;
4079 for (it = cliques.begin(); it != cliques.end(); ++it) {
4080 TMBAD_ASSERT(it->clique_size() == 0);
4081 TMBAD_ASSERT(it->logsum.size() == 1);
4082 ans += it->logsum[0];
4085 for (
size_t i = 0; i < terms_done.size(); i++) {
4086 if (!terms_done[i]) ans += replay.value_dep(i);
4091 global sequential_reduction::marginal() {
4093 replay.forward(
true,
false);
4095 ad_aug ans = get_result();
4101 autopar::autopar(
global &glob,
size_t num_threads)
4103 num_threads(num_threads),
4104 do_aggregate(
false),
4105 keep_all_inv(
false) {
4110 std::vector<Index> max_tree_depth(glob.
opstack.size(), 0);
4113 for (
size_t i = 0; i < glob.
opstack.size(); i++) {
4115 glob.
opstack[i]->dependencies(args, dep);
4116 for (
size_t j = 0; j < dep.size(); j++) {
4117 max_tree_depth[i] =
std::max(max_tree_depth[i], max_tree_depth[dep[j]]);
4120 max_tree_depth[i]++;
4124 std::vector<size_t> ans(glob.
dep_index.size());
4125 for (
size_t j = 0; j < glob.
dep_index.size(); j++) {
4126 ans[j] = max_tree_depth[glob.
dep_index[j]];
4131 void autopar::run() {
4132 std::vector<size_t> ord =
order(max_tree_depth());
4133 std::reverse(ord.begin(), ord.end());
4134 std::vector<bool> visited(glob.
opstack.size(),
false);
4135 std::vector<Index> start;
4136 std::vector<Index> dWork(ord.size());
4137 for (
size_t i = 0; i < ord.size(); i++) {
4139 start[0] = reverse_graph.
dep2op[ord[i]];
4140 reverse_graph.
search(start, visited,
false,
false);
4141 dWork[i] = start.size();
4143 for (
size_t k = 0; k < start.size(); k++) {
4144 Rcout << glob.
opstack[start[k]]->op_name() <<
" ";
4150 std::vector<size_t> thread_assign(ord.size(), 0);
4151 std::vector<size_t> work_by_thread(num_threads, 0);
4152 for (
size_t i = 0; i < dWork.size(); i++) {
4154 thread_assign[i] = 0;
4157 thread_assign[i] = thread_assign[i - 1];
4159 thread_assign[i] = which_min(work_by_thread);
4161 work_by_thread[thread_assign[i]] += dWork[i];
4164 node_split.resize(num_threads);
4165 for (
size_t i = 0; i < ord.size(); i++) {
4166 node_split[thread_assign[i]].push_back(reverse_graph.
dep2op[ord[i]]);
4169 for (
size_t i = 0; i < num_threads; i++) {
4171 node_split[i].insert(node_split[i].begin(), reverse_graph.
inv2op.begin(),
4172 reverse_graph.
inv2op.end());
4173 reverse_graph.
search(node_split[i]);
4178 vglob.resize(num_threads);
4179 inv_idx.resize(num_threads);
4180 dep_idx.resize(num_threads);
4181 std::vector<Index> tmp;
4182 for (
size_t i = 0; i < num_threads; i++) {
4183 glob.subgraph_seq = node_split[i];
4191 for (
size_t i = 0; i < num_threads; i++) {
4192 std::vector<Index> &seq = node_split[i];
4193 for (
size_t j = 0; j < seq.size(); j++) {
4194 if (op2inv_idx[seq[j]] != NA) inv_idx[i].push_back(op2inv_idx[seq[j]]);
4195 if (op2dep_idx[seq[j]] != NA) dep_idx[i].push_back(op2dep_idx[seq[j]]);
4198 dep_idx[i].resize(1);
4207 return (do_aggregate ? num_threads : glob.
dep_index.size());
4210 Index ParalOp::input_size()
const {
return n; }
4212 Index ParalOp::output_size()
const {
return m; }
4214 ParalOp::ParalOp(
const autopar &ap)
4222 size_t num_threads = vglob.size();
4225 #pragma omp parallel for 4228 for (
size_t i = 0; i < num_threads; i++) {
4229 for (
size_t j = 0; j < inv_idx[i].size(); j++) {
4230 vglob[i].value_inv(j) = args.
x(inv_idx[i][j]);
4235 for (
size_t i = 0; i < num_threads; i++) {
4236 for (
size_t j = 0; j < dep_idx[i].size(); j++) {
4237 args.
y(dep_idx[i][j]) = vglob[i].value_dep(j);
4243 size_t num_threads = vglob.size();
4246 #pragma omp parallel for 4249 for (
size_t i = 0; i < num_threads; i++) {
4250 vglob[i].clear_deriv();
4251 for (
size_t j = 0; j < dep_idx[i].size(); j++) {
4252 vglob[i].deriv_dep(j) = args.
dy(dep_idx[i][j]);
4257 for (
size_t i = 0; i < num_threads; i++) {
4258 for (
size_t j = 0; j < inv_idx[i].size(); j++) {
4259 args.
dx(inv_idx[i][j]) += vglob[i].deriv_inv(j);
4264 const char *ParalOp::op_name() {
return "ParalOp"; }
4267 size_t num_threads = vglob.size();
4268 for (
size_t i = 0; i < num_threads; i++) {
4270 std::stringstream ss;
4272 std::string str = ss.str();
4273 cfg2.prefix = cfg2.prefix + str;
4274 vglob[i].print(cfg2);
4278 std::vector<Index> get_likely_expression_duplicates(
4279 const global &glob, std::vector<Index> inv_remap) {
4287 std::vector<hash_t> h = glob.
hash_sweep(cfg);
4288 return radix::first_occurance<Index>(h);
4293 for (
size_t i = 0; i < glob.
opstack.size(); i++) {
4304 global &glob, std::vector<Index> inv_remap) {
4305 std::vector<Index> remap = get_likely_expression_duplicates(glob, inv_remap);
4307 for (
size_t i = 0; i < glob.
inv_index.size(); i++) {
4308 bool accept =
false;
4310 if (inv_remap.size() > 0) {
4311 Index j = inv_remap[i];
4313 accept = remap[var_i] == remap[var_j];
4315 if (!accept) remap[var_i] = var_i;
4318 std::vector<Index> v2o = glob.
var2op();
4319 std::vector<Index> dep;
4327 for (
size_t j = 0, i = 0, nout = 0; j < glob.
opstack.size(); j++, i += nout) {
4328 nout = glob.
opstack[j]->output_size();
4329 bool any_remap =
false;
4330 for (
size_t k = i; k < i + nout; k++) {
4331 if (remap[k] != k) {
4349 if (ok && (nout > 1)) {
4350 for (
size_t k = 1; k < nout; k++) {
4351 ok &= (remap[i + k] < i);
4353 ok &= (v2o[remap[i + k]] == v2o[remap[i]]);
4355 ok &= (remap[i + k] == remap[i] + k);
4359 if (CurOp == invop) {
4373 args.
ptr = glob.subgraph_ptr[v2o[i]];
4375 glob.
opstack[v2o[i]]->dependencies(args, dep1);
4377 args.
ptr = glob.subgraph_ptr[v2o[remap[i]]];
4379 glob.
opstack[v2o[remap[i]]]->dependencies(args, dep2);
4381 ok = (dep1.size() == dep2.size());
4383 bool all_equal =
true;
4384 for (
size_t j = 0; j < dep1.size(); j++) {
4385 all_equal &= (remap[dep1[j]] == remap[dep2[j]]);
4393 for (
size_t k = i; k < i + nout; k++) remap[k] = k;
4398 for (
size_t i = 0; i < remap.size(); i++) {
4399 TMBAD_ASSERT(remap[i] <= i);
4400 TMBAD_ASSERT(remap[remap[i]] == remap[i]);
4406 for (
size_t i = 0; i < glob.
opstack.size(); i++) {
4410 glob.
opstack[i]->dependencies(args, dep);
4411 for (
size_t j = 0; j < dep.I.size(); j++) {
4412 visited.
insert(dep.I[j].first, dep.I[j].second);
4429 std::vector<Index> inv_remap(0);
4432 for (
size_t i = 0; i < glob.
inputs.size(); i++) {
4437 std::vector<Position> inv_positions(
global &glob) {
4438 IndexPair ptr(0, 0);
4439 std::vector<bool> independent_variable = glob.
inv_marks();
4440 std::vector<Position> ans(glob.
inv_index.size());
4442 for (
size_t i = 0; i < glob.
opstack.size(); i++) {
4443 Index nout = glob.
opstack[i]->output_size();
4444 for (Index j = 0; j < nout; j++) {
4445 if (independent_variable[ptr.second + j]) {
4451 glob.
opstack[i]->increment(ptr);
4458 for (
size_t i = 1; i < inv_idx.size(); i++) {
4459 TMBAD_ASSERT(inv_idx[i] > inv_idx[i - 1]);
4461 std::vector<bool> marks(glob.
values.size(),
false);
4462 for (
size_t i = 0; i < inv_idx.size(); i++)
4463 marks[glob.
inv_index[inv_idx[i]]] =
true;
4466 int c = std::count(marks.begin(), marks.end(),
true);
4467 Rcout <<
"marked proportion:" << (double)c / (
double)marks.size() <<
"\n";
4478 #include "integrate.hpp" 4481 double value(
double x) {
return x; }
4483 control::control(
int subdivisions_,
double reltol_,
double abstol_)
4484 : subdivisions(subdivisions_), reltol(reltol_), abstol(abstol_) {}
4487 #include "radix.hpp" 4490 #include "tmbad_allow_comparison.hpp" 4496 bool operator<(
const Scalar &x,
const ad_aug &y) {
return x < y.
Value(); }
4501 bool operator<=(
const Scalar &x,
const ad_aug &y) {
return x <= y.
Value(); }
4506 bool operator>(
const Scalar &x,
const ad_aug &y) {
return x > y.
Value(); }
4511 bool operator>=(
const Scalar &x,
const ad_aug &y) {
return x >= y.
Value(); }
4516 bool operator==(
const Scalar &x,
const ad_aug &y) {
return x == y.
Value(); }
4521 bool operator!=(
const Scalar &x,
const ad_aug &y) {
return x != y.
Value(); }
4527 VSumOp::VSumOp(
size_t n) : n(n) {}
4529 void VSumOp::dependencies(
Args<> &args, Dependencies &dep)
const {
4530 dep.add_segment(args.
input(0), n);
4537 const char *VSumOp::op_name() {
return "VSumOp"; }
4544 Scalar *SegmentRef::value_ptr() {
return (*glob_ptr).values.data() + offset; }
4546 Scalar *SegmentRef::deriv_ptr() {
return (*glob_ptr).derivs.data() + offset; }
4548 SegmentRef::SegmentRef() {}
4550 SegmentRef::SegmentRef(
const Scalar *x) {
4555 SegmentRef::SegmentRef(
global *g, Index o, Index s)
4556 : glob_ptr(g), offset(o), size(s) {}
4558 SegmentRef::SegmentRef(
const ad_segment &x) {
4559 static const size_t K = ScalarPack<SegmentRef>::size;
4560 TMBAD_ASSERT(x.size() == K);
4562 for (
size_t i = 0; i < K; i++) buf[i] = x[i].Value();
4567 bool SegmentRef::isNull() {
return (glob_ptr == NULL); }
4570 Index i = pack.index();
4575 PackOp::PackOp(
const Index n) : n(n) {}
4589 if (tmp.glob_ptr != NULL) {
4592 for (Index i = 0; i < n; i++) dx[i] += dy[i];
4605 Replay *pdx = args.
dx_ptr(0);
4606 for (Index i = 0; i < n; i++) pdx[i] = dx[i];
4609 const char *PackOp::op_name() {
return "PackOp"; }
4611 void PackOp::dependencies(
Args<> &args, Dependencies &dep)
const {
4612 dep.add_segment(args.
input(0), n);
4615 UnpkOp::UnpkOp(
const Index n) : noutput(n) {}
4618 Scalar *y = args.
y_ptr(0);
4621 for (Index i = 0; i < noutput; i++) y[i] = 0;
4624 Scalar *x = srx.value_ptr();
4625 for (Index i = 0; i < noutput; i++) y[i] = x[i];
4638 Replay *pdx = args.
dx_ptr(0);
4639 for (Index i = 0; i < dy_packed.size(); i++) pdx[i] = dy_packed[i];
4642 const char *UnpkOp::op_name() {
return "UnpkOp"; }
4644 void UnpkOp::dependencies(
Args<> &args, Dependencies &dep)
const {
4645 dep.add_segment(args.
input(0), K);
4659 Scalar *
unpack(
const std::vector<Scalar> &x, Index j) {
4660 Index K = ScalarPack<SegmentRef>::size;
4662 return sr.value_ptr();
4665 std::vector<ad_aug> concat(
const std::vector<ad_segment> &x) {
4666 std::vector<ad_aug> ans;
4667 for (
size_t i = 0; i < x.size(); i++) {
4669 for (
size_t j = 0; j < xi.size(); j++) {
4670 ans.push_back(xi[j]);
Automatic differentiation library designed for TMB.
std::vector< Index > op2var(const std::vector< Index > &seq)
Get variables produces by a node seqence.
std::vector< T > subset(const std::vector< T > &x, const std::vector< bool > &y)
Vector subset by boolean mask.
graph reverse_graph(std::vector< bool > keep_var=std::vector< bool >(0))
Construct operator graph with reverse connections.
diff --git a/TMBad_8hpp_source.html b/TMBad_8hpp_source.html
index 563fb39c9..9209ea2b1 100644
--- a/TMBad_8hpp_source.html
+++ b/TMBad_8hpp_source.html
@@ -73,7 +73,7 @@
TMBad.hpp