Skip to content

Commit

Permalink
introduce a version of namespace_iter with a lambda function
Browse files Browse the repository at this point in the history
  • Loading branch information
krangelov committed Dec 28, 2023
1 parent d78aea4 commit 87b6094
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 95 deletions.
18 changes: 18 additions & 0 deletions src/runtime/c/pgf/namespace.h
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,24 @@ void namespace_iter(Namespace<V> map, PgfItor* itor, PgfExn *err)
return;
}

template <class V>
bool namespace_iter(Namespace<V> map, std::function<bool(ref<V>)> &f)
{
if (map == 0)
return true;

if (!namespace_iter(map->left, f))
return false;

if (!f(map->value))
return false;

if (!namespace_iter(map->right, f))
return false;

return true;
}

template <class V>
void namespace_iter_prefix(Namespace<V> map, PgfText *prefix, PgfItor* itor, PgfExn *err)
{
Expand Down
57 changes: 14 additions & 43 deletions src/runtime/c/pgf/pgf.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -463,23 +463,6 @@ void pgf_iter_categories(PgfDB *db, PgfRevision revision,
} PGF_API_END
}

struct PgfItorConcrHelper : PgfItor
{
PgfDB *db;
txn_t txn_id;
PgfItor *itor;
};

static
void iter_concretes_helper(PgfItor *itor, PgfText *key, object value, PgfExn *err)
{
PgfItorConcrHelper* helper = (PgfItorConcrHelper*) itor;
ref<PgfConcr> concr = value;
object rev = helper->db->register_revision(concr.tagged(), helper->txn_id);
helper->db->ref_count++;
helper->itor->fn(helper->itor, key, rev, err);
}

PGF_API
void pgf_iter_concretes(PgfDB *db, PgfRevision revision,
PgfItor *itor, PgfExn *err)
Expand All @@ -490,13 +473,14 @@ void pgf_iter_concretes(PgfDB *db, PgfRevision revision,
DB_scope scope(db, READER_SCOPE);
ref<PgfPGF> pgf = db->revision2pgf(revision, &txn_id);

PgfItorConcrHelper helper;
helper.fn = iter_concretes_helper;
helper.db = db;
helper.txn_id = txn_id;
helper.itor = itor;

namespace_iter(pgf->concretes, &helper, err);
std::function<bool(ref<PgfConcr>)> f =
[txn_id,db,itor,err](ref<PgfConcr> concr) {
object rev = db->register_revision(concr.tagged(), txn_id);
db->ref_count++;
itor->fn(itor, &concr->name, rev, err);
return (err->type == PGF_EXN_NONE);
};
namespace_iter(pgf->concretes, f);
} PGF_API_END
}

Expand Down Expand Up @@ -1609,30 +1593,19 @@ void pgf_create_category(PgfDB *db, PgfRevision revision,
struct PGF_INTERNAL_DECL PgfDropItor : PgfItor
{
ref<PgfPGF> pgf;
ref<PgfConcr> concrete;
PgfText *name;
};

static
void iter_drop_cat_helper2(PgfItor *itor, PgfText *key, object value, PgfExn *err)
{
ref<PgfConcr> concr = value;
PgfText* name = ((PgfDropItor*) itor)->name;

drop_lin(concr, name);
}

static
void iter_drop_cat_helper(PgfItor *itor, PgfText *key, object value, PgfExn *err)
{
ref<PgfPGF> pgf = ((PgfDropItor*) itor)->pgf;

PgfDropItor itor2;
itor2.fn = iter_drop_cat_helper2;
itor2.pgf = 0;
itor2.concrete = 0;
itor2.name = key;
namespace_iter(pgf->concretes, &itor2, err);
std::function<bool(ref<PgfConcr>)> f =
[key,err](ref<PgfConcr> concr) {
drop_lin(concr, key);
return (err->type == PGF_EXN_NONE);
};
namespace_iter(pgf->concretes, f);

ref<PgfAbsFun> fun;
Namespace<PgfAbsFun> funs =
Expand Down Expand Up @@ -1672,8 +1645,6 @@ void pgf_drop_category(PgfDB *db, PgfRevision revision,
PgfDropItor itor;
itor.fn = iter_drop_cat_helper;
itor.pgf = pgf;
itor.concrete = 0;
itor.name = name;
PgfProbspace funs_by_cat =
probspace_delete_by_cat(pgf->abstract.funs_by_cat, &cat->name,
&itor, err);
Expand Down
85 changes: 33 additions & 52 deletions src/runtime/c/pgf/reader.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -373,11 +373,6 @@ struct PGF_INTERNAL_DECL PgfAbsCatCounts
prob_t prob;
};

struct PGF_INTERNAL_DECL PgfProbItor : PgfItor
{
Vector<PgfAbsCatCounts> *cats;
};

static
PgfAbsCatCounts *find_counts(Vector<PgfAbsCatCounts> *cats, PgfText *name)
{
Expand All @@ -399,38 +394,6 @@ PgfAbsCatCounts *find_counts(Vector<PgfAbsCatCounts> *cats, PgfText *name)
return NULL;
}

static
void collect_counts(PgfItor *itor, PgfText *key, object value, PgfExn *err)
{
PgfProbItor* prob_itor = (PgfProbItor*) itor;
ref<PgfAbsFun> absfun = value;

PgfAbsCatCounts *counts =
find_counts(prob_itor->cats, &absfun->type->name);
if (counts != NULL) {
if (isnan(absfun->prob)) {
counts->n_nan_probs++;
} else {
counts->probs_sum += exp(-absfun->prob);
}
}
}

static
void pad_probs(PgfItor *itor, PgfText *key, object value, PgfExn *err)
{
PgfProbItor* prob_itor = (PgfProbItor*) itor;
ref<PgfAbsFun> absfun = value;

if (isnan(absfun->prob)) {
PgfAbsCatCounts *counts =
find_counts(prob_itor->cats, &absfun->type->name);
if (counts != NULL) {
absfun->prob = counts->prob;
}
}
}

void PgfReader::read_abstract(ref<PgfAbstr> abstract)
{
this->abstract = abstract;
Expand All @@ -447,24 +410,42 @@ void PgfReader::read_abstract(ref<PgfAbstr> abstract)
abstract->cats = cats;

if (probs_callback != NULL) {
PgfExn err;
err.type = PGF_EXN_NONE;

PgfProbItor itor;
itor.cats = namespace_to_sorted_names<PgfAbsCat,PgfAbsCatCounts>(abstract->cats);

itor.fn = collect_counts;
namespace_iter(abstract->funs, &itor, &err);

for (size_t i = 0; i < itor.cats->len; i++) {
PgfAbsCatCounts *counts = &itor.cats->data[i];
Vector<PgfAbsCatCounts> *cats = namespace_to_sorted_names<PgfAbsCat,PgfAbsCatCounts>(abstract->cats);

std::function<bool(ref<PgfAbsFun>)> collect_counts =
[cats](ref<PgfAbsFun> absfun) {
PgfAbsCatCounts *counts =
find_counts(cats, &absfun->type->name);
if (counts != NULL) {
if (isnan(absfun->prob)) {
counts->n_nan_probs++;
} else {
counts->probs_sum += exp(-absfun->prob);
}
}
return true;
};
namespace_iter(abstract->funs, collect_counts);

for (size_t i = 0; i < cats->len; i++) {
PgfAbsCatCounts *counts = &cats->data[i];
counts->prob = - logf((1-counts->probs_sum) / counts->n_nan_probs);
}

itor.fn = pad_probs;
namespace_iter(abstract->funs, &itor, &err);

free(itor.cats);
std::function<bool(ref<PgfAbsFun>)> pad_probs =
[cats](ref<PgfAbsFun> absfun) {
if (isnan(absfun->prob)) {
PgfAbsCatCounts *counts =
find_counts(cats, &absfun->type->name);
if (counts != NULL) {
absfun->prob = counts->prob;
}
}
return true;
};
namespace_iter(abstract->funs, pad_probs);

free(cats);
}
}

Expand Down

0 comments on commit 87b6094

Please sign in to comment.