Skip to content

Commit

Permalink
fix bug where non-floating-point answer types were overwritten by def…
Browse files Browse the repository at this point in the history
…ault traces template args
  • Loading branch information
fymue committed Jul 22, 2023
1 parent 3377b11 commit 54d3e10
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 51 deletions.
18 changes: 9 additions & 9 deletions rtlib/traces.hh
Original file line number Diff line number Diff line change
Expand Up @@ -75,21 +75,21 @@ inline bool is_same_index(const index_components &lhs_idx,
/* this records the use of a sub-solution of another NT result <name of other
* NT, index in other NT table with variable number of dimensions, weight
* of edge> */
template<typename answer = double>
template<typename answer>
using Trace = std::tuple<std::string, index_components, answer>;
/* Since a candidate can be composed of multiple sub-solutions e.g.
* mult(A, '*', B), we collect Traces for A and B */
template<typename answer = double>
template<typename answer>
using NTtrace = std::tuple<answer, std::vector<Trace<answer>>>;
/* collection of all traces of an NT in forward pass, together with their
* values. This later normalized into edge weights. */

// type for "traces" vector in NT tables
template<typename answer = double>
template<typename answer>
using Traces = std::vector<Trace<answer>>;

// answer type can be any primitve number type or "Batch"
template<typename answer = double>
template<typename answer>
class candidate {
private:
answer value;
Expand Down Expand Up @@ -195,12 +195,12 @@ class candidate {
// }
};

template<typename answer = double>
template<typename answer>
using NTtraces = std::vector<candidate<answer>>;

// once all use of sub-solutions for candidates is finished, we need to
// normalize their contributions
template<typename answer = double>
template<typename answer>
void normalize_traces(std::vector<Trace<answer>> *tabulated,
const std::vector<candidate<answer>> &candidates,
answer eval,
Expand All @@ -214,7 +214,7 @@ void normalize_traces(std::vector<Trace<answer>> *tabulated,

// overload of normalize_traces taking a function ptr argument
// with two const reference parameters
template<typename answer = double>
template<typename answer>
void normalize_traces(std::vector<Trace<answer>> *tabulated,
const std::vector<candidate<answer>> &candidates,
answer eval,
Expand All @@ -226,7 +226,7 @@ void normalize_traces(std::vector<Trace<answer>> *tabulated,
}
}

template<typename answer = double>
template<typename answer>
void soft_max_hessian_product(std::vector<Trace<answer>> *tabulated,
const NTtraces<answer> &candidates, answer eval) {
for (size_t i = 0; i < candidates.size(); ++i) {
Expand All @@ -236,7 +236,7 @@ void soft_max_hessian_product(std::vector<Trace<answer>> *tabulated,
}
}

template<typename answer = double>
template<typename answer>
inline answer get_trace_weights(const std::vector<Trace<answer>> &traces,
const std::string &to_nt,
const index_components &to_indices,
Expand Down
49 changes: 33 additions & 16 deletions src/alt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1768,30 +1768,38 @@ void Alt::Simple::init_body(AST &ast, Symbol::NT &calling_nt) {
}
}

std::list<Statement::Base*> *Alt::Base::derivatives_create_candidate() {
return NULL;
std::pair<std::list<Statement::Base*>*, std::string>
Alt::Base::derivatives_create_candidate() {
return std::make_pair(nullptr, "");
}

std::list<Statement::Base*> *Alt::Simple::derivatives_create_candidate() {
std::pair<std::list<Statement::Base*>*, std::string>
Alt::Simple::derivatives_create_candidate() {
std::list<Statement::Base*> *stmts_record = \
new std::list<Statement::Base*>();
std::string nt_name;

for (std::list<Fn_Arg::Base*>::iterator i = args.begin();
i != args.end(); ++i) {
if ((*i)->is(Fn_Arg::ALT) && ((*i)->alt_ref()->is(Alt::LINK))) {
Alt::Link *alt = dynamic_cast<Alt::Link*>((*i)->alt_ref());
if (alt->nt->is(Symbol::NONTERMINAL)) {
std::list<Statement::Base*> *x = alt->derivatives_create_candidate();
stmts_record-> insert(stmts_record->end(), x->begin(), x->end());
std::pair<std::list<Statement::Base*>*, std::string>
x = alt->derivatives_create_candidate();
stmts_record->insert(stmts_record->end(),
x.first->begin(), x.first->end());
nt_name = std::move(x.second);
}
}
}

return stmts_record;
return std::make_pair(stmts_record, std::move(nt_name));
}

std::list<Statement::Base*> *Alt::Link::derivatives_create_candidate() {
std::pair<std::list<Statement::Base*>*, std::string>
Alt::Link::derivatives_create_candidate() {
std::list<Statement::Base*> *stmts_record = new std::list<Statement::Base*>();
std::string nt_name;

Expr::Fn_Call *mkidx = new Expr::Fn_Call(
new std::string("index_components"));
Expand All @@ -1806,35 +1814,43 @@ std::list<Statement::Base*> *Alt::Link::derivatives_create_candidate() {

stmts_record->push_back(fn_add);

return stmts_record;
return std::make_pair(stmts_record, *this->nt->name);
}

std::list<Statement::Base*> *Alt::Block::derivatives_create_candidate() {
std::pair<std::list<Statement::Base*>*, std::string>
Alt::Block::derivatives_create_candidate() {
throw LogError(
"Alt::Block::derivatives_create_candidate not properly implemented yet, "
"as Blocks without application of choice function open up the route for"
" huge combinatorics!");

std::list<Statement::Base*> *stmts_record = new std::list<Statement::Base*>();
std::string nt_name;
for (std::list<Alt::Base*>::iterator i = alts.begin(); i != alts.end(); ++i) {
std::list<Statement::Base*> *x = (*i)->derivatives_create_candidate();
stmts_record->insert(stmts_record->end(), x->begin(), x->end());
std::pair<std::list<Statement::Base*>*, std::string>
x = (*i)->derivatives_create_candidate();
nt_name = std::move(x.second);
stmts_record->insert(stmts_record->end(), x.first->begin(), x.first->end());
}
return stmts_record;
return std::make_pair(stmts_record, std::move(nt_name));
}

std::list<Statement::Base*> *Alt::Multi::derivatives_create_candidate() {
std::pair<std::list<Statement::Base*>*, std::string>
Alt::Multi::derivatives_create_candidate() {
throw LogError(
"Alt::Multi::derivatives_create_candidate is not yet implemented!");
return std::make_pair(nullptr, "");
}

void Alt::Base::init_derivative_recording(
AST &ast, std::string *result_name) {
if (ast.current_derivative > 0) {
if (!this->is_partof_outside) {
// test if this alternative uses sub-solutions from other non-terminals
std::list<Statement::Base*> *stmts_record =
derivatives_create_candidate();
std::pair<std::list<Statement::Base*>*, std::string>
stmts_record_p = derivatives_create_candidate();
const std::string &nt_name = stmts_record_p.second;
std::list<Statement::Base*> *stmts_record = stmts_record_p.first;
if (stmts_record && (stmts_record->size() > 0)) {
// TODO(sjanssen): should I use build-in functions? Also for push_back
Statement::Fn_Call *x = new Statement::Fn_Call("set_value");
Expand All @@ -1860,7 +1876,8 @@ void Alt::Base::init_derivative_recording(
if (ast.as_pytorch_module && ast.input.tensor_inputs.all_batched()) {
candidate_name = new std::string("candidate<TensorBatch>");
} else {
candidate_name = new std::string("candidate<>");
candidate_name =
new std::string("candidate<" + nt_name + "_table_t::AnswerType>");
}

Statement::Var_Decl *candidate = new Statement::Var_Decl(
Expand Down
15 changes: 10 additions & 5 deletions src/alt.hh
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,8 @@ class Base {

public:
// for derivative code generation: add code for candidate creation
virtual std::list<Statement::Base*> *derivatives_create_candidate();
virtual std::pair<std::list<Statement::Base*>*, std::string>
derivatives_create_candidate();
};


Expand Down Expand Up @@ -609,7 +610,8 @@ class Simple : public Base {
std::list<Statement::Base*> *inner_code;

public:
std::list<Statement::Base*> *derivatives_create_candidate();
std::pair<std::list<Statement::Base*>*, std::string>
derivatives_create_candidate();
};


Expand Down Expand Up @@ -748,7 +750,8 @@ class Link : public Base {
Alt::Base *find_block_parent(const Alt::Base &block);

public:
std::list<Statement::Base*> *derivatives_create_candidate();
std::pair<std::list<Statement::Base*>*, std::string>
derivatives_create_candidate();
};


Expand Down Expand Up @@ -834,7 +837,8 @@ class Block : public Base {
Alt::Base* find_block();
Alt::Base *find_block_parent(const Alt::Base &block);

std::list<Statement::Base*> *derivatives_create_candidate();
std::pair<std::list<Statement::Base*>*, std::string>
derivatives_create_candidate();
};


Expand Down Expand Up @@ -924,7 +928,8 @@ class Multi : public Base {
Alt::Base* find_block();
Alt::Base *find_block_parent(const Alt::Base &block);

std::list<Statement::Base*> *derivatives_create_candidate();
std::pair<std::list<Statement::Base*>*, std::string>
derivatives_create_candidate();
};

} // namespace Alt
Expand Down
5 changes: 4 additions & 1 deletion src/cpp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1094,9 +1094,12 @@ void Printer::Cpp::print(const Statement::Table_Decl &t) {
const std::list<Statement::Var_Decl*> &ns = t.ns();

stream << indent() << "class " << tname << " {" << endl;
inc_indent();

stream << indent() << " public:" << endl;
inc_indent();
stream << indent() << "using AnswerType = " << dtype << ';' << endl;
dec_indent();

stream << indent() << " private:" << endl;
inc_indent();

Expand Down
3 changes: 2 additions & 1 deletion src/symbol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1289,7 +1289,8 @@ void Symbol::NT::codegen(AST &ast) {
if (ast.as_pytorch_module && ast.input.tensor_inputs.all_batched()) {
nt_traces = new std::string("NTtraces<TensorBatch>");
} else {
nt_traces = new std::string("NTtraces<>");
nt_traces =
new std::string("NTtraces<" + *name + "_table_t::AnswerType>");
}
stmts.push_back(new Statement::Var_Decl(new ::Type::External(
nt_traces), "candidates"));
Expand Down
23 changes: 4 additions & 19 deletions src/tablegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ Fn_Def *Tablegen::gen_set_traces(int forDerivative) {
if (batched_) {
nt_traces = new std::string("NTtraces<TensorBatch>");
} else {
nt_traces = new std::string("NTtraces<>");
nt_traces = new std::string("NTtraces<AnswerType>");
}

f->add_para(new ::Type::External(nt_traces), new std::string("candidates"));
Expand All @@ -538,20 +538,11 @@ Fn_Def *Tablegen::gen_set_traces(int forDerivative) {
a->add_arg(new Expr::Less(off, new Expr::Fn_Call(new std::string("size"))));
c.push_back(a);

/*
* if batched Tensor input is processed, the templated trace functions
* need to know that so we need to pass "TensorBatch" as the template here;
* for backwards compatibility, we also need to add the empty template "<>"
* to instruct the compiler to use the default template type (double);
* this isn't required anymore in C++17+, but still is required in C++11
*/
std::string fn_norm_name;
if (forDerivative == 1) {
fn_norm_name = batched_ ? "normalize_traces<TensorBatch>" :
"normalize_traces<>";
fn_norm_name = "normalize_traces";
} else if (forDerivative == 2) {
fn_norm_name = batched_ ? "soft_max_hessian_product<TensorBatch>" :
"soft_max_hessian_product<>";
fn_norm_name = "soft_max_hessian_product";
}

Statement::Fn_Call *fn_norm = new Statement::Fn_Call(fn_norm_name);
Expand Down Expand Up @@ -625,13 +616,7 @@ Fn_Def *Tablegen::gen_get_traces() {

Statement::Var_Decl *r = new Statement::Var_Decl(dtype, "res");

std::string *get_trace_weights;
if (batched_) {
get_trace_weights = new std::string("get_trace_weights<TensorBatch>");
} else {
get_trace_weights = new std::string("get_trace_weights<>");
}

std::string *get_trace_weights = new std::string("get_trace_weights");
Expr::Fn_Call *fn_norm = new Expr::Fn_Call(get_trace_weights);
fn_norm->add_arg(new Var_Acc::Array(new Var_Acc::Plain(
new std::string("traces")), off));
Expand Down

0 comments on commit 54d3e10

Please sign in to comment.