From 98cb898aecb022639bd35c41116cbc25a991ef6d Mon Sep 17 00:00:00 2001 From: xlauko Date: Fri, 10 Jan 2025 14:34:54 +0100 Subject: [PATCH] core: Filter symbol users based on typed references. --- include/vast/Dialect/Core/SymbolTable.hpp | 13 +- lib/vast/Dialect/Core/SymbolTable.cpp | 249 ++++++++++++++-------- 2 files changed, 161 insertions(+), 101 deletions(-) diff --git a/include/vast/Dialect/Core/SymbolTable.hpp b/include/vast/Dialect/Core/SymbolTable.hpp index e693b8b044..322f229c54 100644 --- a/include/vast/Dialect/Core/SymbolTable.hpp +++ b/include/vast/Dialect/Core/SymbolTable.hpp @@ -14,6 +14,7 @@ VAST_UNRELAX_WARNINGS #include "vast/Util/Common.hpp" #include "vast/Util/TypeList.hpp" +#include "vast/Dialect/Core/CoreAttributes.hpp" #include "vast/Dialect/Core/Interfaces/SymbolInterface.hpp" #include @@ -168,18 +169,14 @@ namespace vast::core { // Get all of the uses of the given symbol that are nested within the given // operation 'from'. This does not traverse into any nested symbol tables. - static symbol_use_range get_direct_symbol_uses(operation symbol, operation from); - static symbol_use_range get_direct_symbol_uses(string_attr symbol, operation from); - static symbol_use_range get_direct_symbol_uses(operation symbol, region_ptr from); - static symbol_use_range get_direct_symbol_uses(string_attr symbol, region_ptr from); + static symbol_use_range get_direct_symbol_uses(operation symbol, operation scope); + static symbol_use_range get_direct_symbol_uses(operation symbol, region_ptr scope); // Get all of the uses of the given symbol that are nested within the given // operation 'from'. In contrast to mlir::SymbolTable::getSymbolUses, this // function traverses into nested symbol tables. - static symbol_use_range get_symbol_uses(operation symbol, operation from); - static symbol_use_range get_symbol_uses(string_attr symbol, operation from); - static symbol_use_range get_symbol_uses(operation symbol, region_ptr from); - static symbol_use_range get_symbol_uses(string_attr symbol, region_ptr from); + static symbol_use_range get_symbol_uses(operation symbol, operation scope); + static symbol_use_range get_symbol_uses(operation symbol, region_ptr scope); protected: diff --git a/lib/vast/Dialect/Core/SymbolTable.cpp b/lib/vast/Dialect/Core/SymbolTable.cpp index 0775132759..c13db542c7 100644 --- a/lib/vast/Dialect/Core/SymbolTable.cpp +++ b/lib/vast/Dialect/Core/SymbolTable.cpp @@ -109,75 +109,164 @@ namespace vast::core { return op->getAttrOfType< string_attr >(symbol_attr_name()); } - bool has_symbol_ref_attr(operation op, symbol_ref_attr symbol) { - return op->getAttrDictionary().walk< mlir::WalkOrder::PreOrder >( - [symbol](mlir::SymbolRefAttr attr) { - VAST_ASSERT(attr.getNestedReferences().empty()); - VAST_ASSERT(symbol.getNestedReferences().empty()); - if (attr.getRootReference() == symbol.getRootReference()) { + // FIXME: This is not nice. Can this be inferred from the symbol reference directly? + enum class reference_kind { var, type, func, label, member, enum_constant, elaborated_type }; + + reference_kind get_reference_kind(symbol_ref_attr attr) { + if (mlir::isa< var_symbol_ref_attr >(attr)) + return reference_kind::var; + if (mlir::isa< type_symbol_ref_attr >(attr)) + return reference_kind::type; + if (mlir::isa< func_symbol_ref_attr >(attr)) + return reference_kind::func; + if (mlir::isa< label_symbol_ref_attr >(attr)) + return reference_kind::label; + if (mlir::isa< member_var_symbol_ref_attr >(attr)) + return reference_kind::member; + if (mlir::isa< enum_constant_symbol_ref_attr >(attr)) + return reference_kind::enum_constant; + if (mlir::isa< elaborated_type_symbol_ref_attr >(attr)) + return reference_kind::elaborated_type; + + VAST_UNREACHABLE("unrecognized reference kind"); + } + + reference_kind get_reference_kind(operation symbol) { + if (mlir::isa< var_symbol >(symbol)) + return reference_kind::var; + if (mlir::isa< type_symbol >(symbol)) + return reference_kind::type; + if (mlir::isa< func_symbol >(symbol)) + return reference_kind::func; + if (mlir::isa< label_symbol >(symbol)) + return reference_kind::label; + if (mlir::isa< member_symbol >(symbol)) + return reference_kind::member; + if (mlir::isa< enum_constant_symbol >(symbol)) + return reference_kind::enum_constant; + if (mlir::isa< elaborated_type_symbol >(symbol)) + return reference_kind::elaborated_type; + + VAST_UNREACHABLE("unrecognized reference kind"); + } + + bool is_reference_of(symbol_ref_attr attr, operation symbol) { + return get_reference_kind(attr) == get_reference_kind(symbol); + } + + symbol_ref_attr get_symbol_ref_attr(operation op, operation symbol) { + symbol_ref_attr result; + op->getAttrDictionary().walk< mlir::WalkOrder::PreOrder >( + [&] (symbol_ref_attr attr) { + if (!is_reference_of(attr, symbol)) { + return mlir::WalkResult::skip(); + } + + if (attr.getRootReference() == get_symbol_name(symbol)) { + result = attr; return mlir::WalkResult::interrupt(); } // Don't walk nested references. return mlir::WalkResult::skip(); } - ) == mlir::WalkResult::interrupt(); + ); + + return result; + } + + struct symbol_scope { + // The first effective operation in the scope + // allows to reduce the search space in the region. + // + // Can be used if definition of symbol is in middle of region, we want + // to look at references only after the definition. + // + // If scope_begin is not set, the scope is the region itself. + operation scope_begin; + region_ptr scope; + }; + + gap::generator< operation > operations(symbol_scope region) { + // TBD: use proper dominance analysis + for (auto &bb : *region.scope) { + for (auto &op : bb) { + if (&bb != region.scope_begin->getBlock()) { + co_yield &op; + } else if (!region.scope_begin || region.scope_begin->isBeforeInBlock(&op)) { + co_yield &op; + } + } + } + } + + std::optional< symbol_scope > constrain_ancestor_scope(operation scope, operation symbol) { + if (symbol->getParentRegion()->isAncestor(scope->getParentRegion())) + return symbol_scope{ symbol, symbol->getParentRegion() }; + return std::nullopt; + } + + std::optional< symbol_scope > constrain_ancestor_scope(region_ptr scope, operation symbol) { + if (symbol->getParentRegion()->isAncestor(scope)) + return symbol_scope{ symbol, scope }; + return std::nullopt; + } + + gap::generator< symbol_scope > symbol_scopes(auto scope, operation symbol) { + // If symbol is defined in ancestor region of scope, return the most immediate + // region and constrain it to scope from symbol definition ownwards + if (auto constrained = constrain_ancestor_scope(scope, symbol)) { + co_yield *constrained; + } else { + // Else symbol is defined above scope, therefore references can be anywhere in scope + for (auto region : direct_regions(scope)) { + co_yield symbol_scope{ nullptr, region }; + } + } } - auto operations(region_ptr region) { return vws::all(*region) | vws::join; } - auto operations(operation op) { return gmw::operations(op); } + auto direct_symbol_uses_in_scope(operation symbol, symbol_scope scope) + -> gap::generator< symbol_use > + { + for (auto op : operations(scope)) { + if (auto symbol_ref = get_symbol_ref_attr(op, symbol)) { + co_yield symbol_use{ op, symbol_ref }; + } + } + } - auto direct_symbol_uses_in_scope(symbol_ref_attr symbol, auto scope) -> gap::generator< symbol_use > { - for (auto &op : operations(scope)) { - if (has_symbol_ref_attr(&op, symbol)) { - co_yield symbol_use{ &op, symbol }; + auto direct_symbol_uses_in_scope(operation symbol, auto scope) + -> gap::generator< symbol_use > + { + for (auto scope : symbol_scopes(scope, symbol)) { + for (auto use : direct_symbol_uses_in_scope(symbol, scope)) { + co_yield use; } } } + namespace detail { - symbol_use_range get_direct_symbol_uses_impl(symbol_ref_attr symbol, auto root) { + symbol_use_range get_direct_symbol_uses_impl(operation symbol, auto root) { VAST_ASSERT(symbol); std::vector< symbol_use > uses; + for (auto use : direct_symbol_uses_in_scope(symbol, root)) { uses.push_back(use); } - return symbol_use_range(std::move(uses)); - } - symbol_use_range get_direct_symbol_uses_impl(string_attr symbol, auto root) { - return get_direct_symbol_uses_impl(symbol_ref_attr::get(symbol), root); - } - - symbol_use_range get_direct_symbol_uses_impl(operation symbol, auto root) { - return get_direct_symbol_uses_impl(get_symbol_name(symbol), root); + return symbol_use_range(std::move(uses)); } } // namespace detail - symbol_use_range symbol_table::get_direct_symbol_uses( - operation symbol, operation from - ) { - return detail::get_direct_symbol_uses_impl(symbol, from); + symbol_use_range symbol_table::get_direct_symbol_uses(operation symbol, operation scope) { + return detail::get_direct_symbol_uses_impl(symbol, scope); } - symbol_use_range symbol_table::get_direct_symbol_uses( - string_attr symbol, operation from - ) { - return detail::get_direct_symbol_uses_impl(symbol, from); - } - - symbol_use_range symbol_table::get_direct_symbol_uses( - operation symbol, region_ptr from - ) { - return detail::get_direct_symbol_uses_impl(symbol, from); - } - - symbol_use_range symbol_table::get_direct_symbol_uses( - string_attr symbol, region_ptr from - ) { - return detail::get_direct_symbol_uses_impl(symbol, from); + symbol_use_range symbol_table::get_direct_symbol_uses(operation symbol, region_ptr scope) { + return detail::get_direct_symbol_uses_impl(symbol, scope); } // @@ -185,71 +274,45 @@ namespace vast::core { // namespace detail { - auto nested_scopes(operation root) -> gap::recursive_generator< region_ptr >; - - auto nested_scopes(region_ptr root) -> gap::recursive_generator< region_ptr > { - co_yield root; - for (auto &op : operations(root)) { - co_yield nested_scopes(&op); - } - } - - auto nested_scopes(operation root) -> gap::recursive_generator< region_ptr > { - if (root->getNumRegions() == 0) - co_return; - for (auto region : direct_regions(root)) { - co_yield nested_scopes(region); - } - } - symbol_use_range get_symbol_uses_impl(symbol_ref_attr symbol, auto root) { + symbol_use_range get_symbol_uses_impl(operation symbol, auto scope) { VAST_ASSERT(symbol); std::vector< symbol_use > uses; - for (auto scope : nested_scopes(root)) { - for (auto use : direct_symbol_uses_in_scope(symbol, scope)) { - uses.push_back(use); + + auto symbol_region = symbol->getParentRegion(); + scope->walk([&](operation op) { + if (op->getNumRegions() == 0) + return mlir::WalkResult::skip(); + + bool skip = true; + for (auto ®ion : op->getRegions()) { + if (symbol_region->isAncestor(®ion)) { + for (auto use : direct_symbol_uses_in_scope(symbol, ®ion)) { + uses.push_back(use); + } + skip = false; + } } - } - return symbol_use_range(std::move(uses)); - } - symbol_use_range get_symbol_uses_impl(string_attr symbol, auto root) { - return get_symbol_uses_impl(symbol_ref_attr::get(symbol), root); - } + return skip ? mlir::WalkResult::skip() : mlir::WalkResult::advance(); + }); - symbol_use_range get_symbol_uses_impl(operation symbol, auto root) { - return get_symbol_uses_impl(get_symbol_name(symbol), root); + return symbol_use_range(std::move(uses)); } - } // namespace detail - symbol_use_range symbol_table::get_symbol_uses( - operation symbol, operation from - ) { - return detail::get_symbol_uses_impl(symbol, from); - } + } // namespace detail - symbol_use_range symbol_table::get_symbol_uses( - string_attr symbol, operation from - ) { - return detail::get_symbol_uses_impl(symbol, from); + symbol_use_range symbol_table::get_symbol_uses(operation symbol, operation scope) { + return detail::get_symbol_uses_impl(symbol, scope); } - symbol_use_range symbol_table::get_symbol_uses( - operation symbol, region_ptr from - ) { - return detail::get_symbol_uses_impl(symbol, from); + symbol_use_range symbol_table::get_symbol_uses(operation symbol, region_ptr scope) { + return detail::get_symbol_uses_impl(symbol, scope); } - symbol_use_range symbol_table::get_symbol_uses( - string_attr symbol, region_ptr from - ) { - return detail::get_symbol_uses_impl(symbol, from); - } - symbol_use_range get_symbol_uses( - operation symbol, operation from - ) { - return symbol_table::get_symbol_uses(symbol, from); + symbol_use_range get_symbol_uses(operation symbol, operation scope) { + return symbol_table::get_symbol_uses(symbol, scope); } } // namespace vast::core