Skip to content

Commit

Permalink
core: Filter symbol users based on typed references.
Browse files Browse the repository at this point in the history
  • Loading branch information
xlauko committed Jan 10, 2025
1 parent ed7365a commit 98cb898
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 101 deletions.
13 changes: 5 additions & 8 deletions include/vast/Dialect/Core/SymbolTable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ VAST_UNRELAX_WARNINGS
#include "vast/Util/Common.hpp"
#include "vast/Util/TypeList.hpp"

#include "vast/Dialect/Core/CoreAttributes.hpp"
#include "vast/Dialect/Core/Interfaces/SymbolInterface.hpp"

#include <gap/coro/generator.hpp>
Expand Down Expand Up @@ -168,18 +169,14 @@ namespace vast::core {

// Get all of the uses of the given symbol that are nested within the given
// operation 'from'. This does not traverse into any nested symbol tables.
static symbol_use_range get_direct_symbol_uses(operation symbol, operation from);
static symbol_use_range get_direct_symbol_uses(string_attr symbol, operation from);
static symbol_use_range get_direct_symbol_uses(operation symbol, region_ptr from);
static symbol_use_range get_direct_symbol_uses(string_attr symbol, region_ptr from);
static symbol_use_range get_direct_symbol_uses(operation symbol, operation scope);
static symbol_use_range get_direct_symbol_uses(operation symbol, region_ptr scope);

// Get all of the uses of the given symbol that are nested within the given
// operation 'from'. In contrast to mlir::SymbolTable::getSymbolUses, this
// function traverses into nested symbol tables.
static symbol_use_range get_symbol_uses(operation symbol, operation from);
static symbol_use_range get_symbol_uses(string_attr symbol, operation from);
static symbol_use_range get_symbol_uses(operation symbol, region_ptr from);
static symbol_use_range get_symbol_uses(string_attr symbol, region_ptr from);
static symbol_use_range get_symbol_uses(operation symbol, operation scope);
static symbol_use_range get_symbol_uses(operation symbol, region_ptr scope);

protected:

Expand Down
249 changes: 156 additions & 93 deletions lib/vast/Dialect/Core/SymbolTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,147 +109,210 @@ namespace vast::core {
return op->getAttrOfType< string_attr >(symbol_attr_name());
}

bool has_symbol_ref_attr(operation op, symbol_ref_attr symbol) {
return op->getAttrDictionary().walk< mlir::WalkOrder::PreOrder >(
[symbol](mlir::SymbolRefAttr attr) {
VAST_ASSERT(attr.getNestedReferences().empty());
VAST_ASSERT(symbol.getNestedReferences().empty());
if (attr.getRootReference() == symbol.getRootReference()) {
// FIXME: This is not nice. Can this be inferred from the symbol reference directly?
enum class reference_kind { var, type, func, label, member, enum_constant, elaborated_type };

reference_kind get_reference_kind(symbol_ref_attr attr) {
if (mlir::isa< var_symbol_ref_attr >(attr))
return reference_kind::var;
if (mlir::isa< type_symbol_ref_attr >(attr))
return reference_kind::type;
if (mlir::isa< func_symbol_ref_attr >(attr))
return reference_kind::func;
if (mlir::isa< label_symbol_ref_attr >(attr))
return reference_kind::label;
if (mlir::isa< member_var_symbol_ref_attr >(attr))
return reference_kind::member;
if (mlir::isa< enum_constant_symbol_ref_attr >(attr))
return reference_kind::enum_constant;
if (mlir::isa< elaborated_type_symbol_ref_attr >(attr))
return reference_kind::elaborated_type;

VAST_UNREACHABLE("unrecognized reference kind");
}

reference_kind get_reference_kind(operation symbol) {
if (mlir::isa< var_symbol >(symbol))
return reference_kind::var;
if (mlir::isa< type_symbol >(symbol))
return reference_kind::type;
if (mlir::isa< func_symbol >(symbol))
return reference_kind::func;
if (mlir::isa< label_symbol >(symbol))
return reference_kind::label;
if (mlir::isa< member_symbol >(symbol))
return reference_kind::member;
if (mlir::isa< enum_constant_symbol >(symbol))
return reference_kind::enum_constant;
if (mlir::isa< elaborated_type_symbol >(symbol))
return reference_kind::elaborated_type;

VAST_UNREACHABLE("unrecognized reference kind");
}

bool is_reference_of(symbol_ref_attr attr, operation symbol) {
return get_reference_kind(attr) == get_reference_kind(symbol);
}

symbol_ref_attr get_symbol_ref_attr(operation op, operation symbol) {
symbol_ref_attr result;
op->getAttrDictionary().walk< mlir::WalkOrder::PreOrder >(
[&] (symbol_ref_attr attr) {
if (!is_reference_of(attr, symbol)) {
return mlir::WalkResult::skip();
}

if (attr.getRootReference() == get_symbol_name(symbol)) {
result = attr;
return mlir::WalkResult::interrupt();
}

// Don't walk nested references.
return mlir::WalkResult::skip();
}
) == mlir::WalkResult::interrupt();
);

return result;
}

struct symbol_scope {
// The first effective operation in the scope
// allows to reduce the search space in the region.
//
// Can be used if definition of symbol is in middle of region, we want
// to look at references only after the definition.
//
// If scope_begin is not set, the scope is the region itself.
operation scope_begin;
region_ptr scope;
};

gap::generator< operation > operations(symbol_scope region) {
// TBD: use proper dominance analysis
for (auto &bb : *region.scope) {
for (auto &op : bb) {
if (&bb != region.scope_begin->getBlock()) {
co_yield &op;
} else if (!region.scope_begin || region.scope_begin->isBeforeInBlock(&op)) {
co_yield &op;
}
}
}
}

std::optional< symbol_scope > constrain_ancestor_scope(operation scope, operation symbol) {
if (symbol->getParentRegion()->isAncestor(scope->getParentRegion()))
return symbol_scope{ symbol, symbol->getParentRegion() };
return std::nullopt;
}

std::optional< symbol_scope > constrain_ancestor_scope(region_ptr scope, operation symbol) {
if (symbol->getParentRegion()->isAncestor(scope))
return symbol_scope{ symbol, scope };
return std::nullopt;
}

gap::generator< symbol_scope > symbol_scopes(auto scope, operation symbol) {
// If symbol is defined in ancestor region of scope, return the most immediate
// region and constrain it to scope from symbol definition ownwards
if (auto constrained = constrain_ancestor_scope(scope, symbol)) {
co_yield *constrained;
} else {
// Else symbol is defined above scope, therefore references can be anywhere in scope
for (auto region : direct_regions(scope)) {
co_yield symbol_scope{ nullptr, region };
}
}
}

auto operations(region_ptr region) { return vws::all(*region) | vws::join; }
auto operations(operation op) { return gmw::operations(op); }
auto direct_symbol_uses_in_scope(operation symbol, symbol_scope scope)
-> gap::generator< symbol_use >
{
for (auto op : operations(scope)) {
if (auto symbol_ref = get_symbol_ref_attr(op, symbol)) {
co_yield symbol_use{ op, symbol_ref };
}
}
}

auto direct_symbol_uses_in_scope(symbol_ref_attr symbol, auto scope) -> gap::generator< symbol_use > {
for (auto &op : operations(scope)) {
if (has_symbol_ref_attr(&op, symbol)) {
co_yield symbol_use{ &op, symbol };
auto direct_symbol_uses_in_scope(operation symbol, auto scope)
-> gap::generator< symbol_use >
{
for (auto scope : symbol_scopes(scope, symbol)) {
for (auto use : direct_symbol_uses_in_scope(symbol, scope)) {
co_yield use;
}
}
}


namespace detail {

symbol_use_range get_direct_symbol_uses_impl(symbol_ref_attr symbol, auto root) {
symbol_use_range get_direct_symbol_uses_impl(operation symbol, auto root) {
VAST_ASSERT(symbol);
std::vector< symbol_use > uses;

for (auto use : direct_symbol_uses_in_scope(symbol, root)) {
uses.push_back(use);
}
return symbol_use_range(std::move(uses));
}

symbol_use_range get_direct_symbol_uses_impl(string_attr symbol, auto root) {
return get_direct_symbol_uses_impl(symbol_ref_attr::get(symbol), root);
}

symbol_use_range get_direct_symbol_uses_impl(operation symbol, auto root) {
return get_direct_symbol_uses_impl(get_symbol_name(symbol), root);
return symbol_use_range(std::move(uses));
}

} // namespace detail

symbol_use_range symbol_table::get_direct_symbol_uses(
operation symbol, operation from
) {
return detail::get_direct_symbol_uses_impl(symbol, from);
symbol_use_range symbol_table::get_direct_symbol_uses(operation symbol, operation scope) {
return detail::get_direct_symbol_uses_impl(symbol, scope);
}

symbol_use_range symbol_table::get_direct_symbol_uses(
string_attr symbol, operation from
) {
return detail::get_direct_symbol_uses_impl(symbol, from);
}

symbol_use_range symbol_table::get_direct_symbol_uses(
operation symbol, region_ptr from
) {
return detail::get_direct_symbol_uses_impl(symbol, from);
}

symbol_use_range symbol_table::get_direct_symbol_uses(
string_attr symbol, region_ptr from
) {
return detail::get_direct_symbol_uses_impl(symbol, from);
symbol_use_range symbol_table::get_direct_symbol_uses(operation symbol, region_ptr scope) {
return detail::get_direct_symbol_uses_impl(symbol, scope);
}

//
// symbol uses
//

namespace detail {
auto nested_scopes(operation root) -> gap::recursive_generator< region_ptr >;

auto nested_scopes(region_ptr root) -> gap::recursive_generator< region_ptr > {
co_yield root;
for (auto &op : operations(root)) {
co_yield nested_scopes(&op);
}
}

auto nested_scopes(operation root) -> gap::recursive_generator< region_ptr > {
if (root->getNumRegions() == 0)
co_return;
for (auto region : direct_regions(root)) {
co_yield nested_scopes(region);
}
}

symbol_use_range get_symbol_uses_impl(symbol_ref_attr symbol, auto root) {
symbol_use_range get_symbol_uses_impl(operation symbol, auto scope) {
VAST_ASSERT(symbol);
std::vector< symbol_use > uses;
for (auto scope : nested_scopes(root)) {
for (auto use : direct_symbol_uses_in_scope(symbol, scope)) {
uses.push_back(use);

auto symbol_region = symbol->getParentRegion();
scope->walk([&](operation op) {
if (op->getNumRegions() == 0)
return mlir::WalkResult::skip();

bool skip = true;
for (auto &region : op->getRegions()) {
if (symbol_region->isAncestor(&region)) {
for (auto use : direct_symbol_uses_in_scope(symbol, &region)) {
uses.push_back(use);
}
skip = false;
}
}
}
return symbol_use_range(std::move(uses));
}

symbol_use_range get_symbol_uses_impl(string_attr symbol, auto root) {
return get_symbol_uses_impl(symbol_ref_attr::get(symbol), root);
}
return skip ? mlir::WalkResult::skip() : mlir::WalkResult::advance();
});

symbol_use_range get_symbol_uses_impl(operation symbol, auto root) {
return get_symbol_uses_impl(get_symbol_name(symbol), root);
return symbol_use_range(std::move(uses));
}
} // namespace detail

symbol_use_range symbol_table::get_symbol_uses(
operation symbol, operation from
) {
return detail::get_symbol_uses_impl(symbol, from);
}
} // namespace detail

symbol_use_range symbol_table::get_symbol_uses(
string_attr symbol, operation from
) {
return detail::get_symbol_uses_impl(symbol, from);
symbol_use_range symbol_table::get_symbol_uses(operation symbol, operation scope) {
return detail::get_symbol_uses_impl(symbol, scope);
}

symbol_use_range symbol_table::get_symbol_uses(
operation symbol, region_ptr from
) {
return detail::get_symbol_uses_impl(symbol, from);
symbol_use_range symbol_table::get_symbol_uses(operation symbol, region_ptr scope) {
return detail::get_symbol_uses_impl(symbol, scope);
}

symbol_use_range symbol_table::get_symbol_uses(
string_attr symbol, region_ptr from
) {
return detail::get_symbol_uses_impl(symbol, from);
}

symbol_use_range get_symbol_uses(
operation symbol, operation from
) {
return symbol_table::get_symbol_uses(symbol, from);
symbol_use_range get_symbol_uses(operation symbol, operation scope) {
return symbol_table::get_symbol_uses(symbol, scope);
}

} // namespace vast::core

0 comments on commit 98cb898

Please sign in to comment.