Skip to content

Commit 5508c1e

Browse files
authored
[OM] Evaluator: Support graph regions (#6249)
This commit adds support for graph regions for evaluator. `ReferenceValue`, a new subclass of `EvaluatorValue`, is added to behave as pointers. `ReferenceValue` can be used as alias to different values and is created for `class.object.field` operation because `class.object.field` can access fields across class hierarchies and the fields might not be initilized yet. `RefenceValue` is not exposed to outside of evaluator implemenation. `EvaluatorValue::finalize` shrinks intermidiate `RefenceValue` in the evaluator value. Evaluation algorithm is changed to worklist-based iteration. `instantiae` method first traverses the whole IR including sub-class, and create partially evaluaed values for all values and add these values to the worklist. After that we evaluate values until there is no partially evaluaed value. Fix #5834
1 parent 307fb58 commit 5508c1e

File tree

6 files changed

+660
-131
lines changed

6 files changed

+660
-131
lines changed

include/circt/Dialect/OM/Evaluator/Evaluator.h

Lines changed: 213 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
#include "mlir/IR/MLIRContext.h"
2121
#include "mlir/IR/SymbolTable.h"
2222
#include "mlir/Support/LogicalResult.h"
23+
#include "llvm/ADT/SmallPtrSet.h"
24+
25+
#include <queue>
26+
#include <utility>
2327

2428
namespace circt {
2529
namespace om {
@@ -36,25 +40,84 @@ using EvaluatorValuePtr = std::shared_ptr<EvaluatorValue>;
3640
using ObjectFields = SmallDenseMap<StringAttr, EvaluatorValuePtr>;
3741

3842
/// Base class for evaluator runtime values.
39-
/// Enables the shared_from_this functionality so Object pointers can be passed
40-
/// through the CAPI and unwrapped back into C++ smart pointers with the
41-
/// appropriate reference count.
43+
/// Enables the shared_from_this functionality so Evaluator Value pointers can
44+
/// be passed through the CAPI and unwrapped back into C++ smart pointers with
45+
/// the appropriate reference count.
4246
struct EvaluatorValue : std::enable_shared_from_this<EvaluatorValue> {
4347
// Implement LLVM RTTI.
44-
enum class Kind { Attr, Object, List, Tuple, Map };
45-
EvaluatorValue(MLIRContext *ctx, Kind kind) : kind(kind) {}
48+
enum class Kind { Attr, Object, List, Tuple, Map, Reference };
49+
EvaluatorValue(MLIRContext *ctx, Kind kind) : kind(kind), ctx(ctx) {}
4650
Kind getKind() const { return kind; }
4751
MLIRContext *getContext() const { return ctx; }
4852

53+
// Return true the value is fully evaluated.
54+
bool isFullyEvaluated() const { return fullyEvaluated; }
55+
void markFullyEvaluated() {
56+
assert(!fullyEvaluated && "should not mark twice");
57+
fullyEvaluated = true;
58+
}
59+
60+
// Return a MLIR type which the value represents.
61+
Type getType() const;
62+
63+
// Finalize the evaluator value. Strip intermidiate reference values.
64+
LogicalResult finalize();
65+
4966
private:
5067
const Kind kind;
5168
MLIRContext *ctx;
69+
bool fullyEvaluated = false;
70+
bool finalized = false;
71+
};
72+
73+
/// Values which can be used as pointers to different values.
74+
/// ReferenceValue is replaced with its element and erased at the end of
75+
/// evaluation.
76+
struct ReferenceValue : EvaluatorValue {
77+
ReferenceValue(Type type)
78+
: EvaluatorValue(type.getContext(), Kind::Reference), value(nullptr),
79+
type(type) {}
80+
81+
// Implement LLVM RTTI.
82+
static bool classof(const EvaluatorValue *e) {
83+
return e->getKind() == Kind::Reference;
84+
}
85+
86+
Type getValueType() const { return type; }
87+
EvaluatorValuePtr getValue() const { return value; }
88+
void setValue(EvaluatorValuePtr newValue) {
89+
value = std::move(newValue);
90+
markFullyEvaluated();
91+
}
92+
93+
// Finalize the value.
94+
LogicalResult finalizeImpl();
95+
96+
// Return the first non-reference value that is reachable from the reference.
97+
FailureOr<EvaluatorValuePtr> getStrippedValue() const {
98+
llvm::SmallPtrSet<ReferenceValue *, 4> visited;
99+
auto currentValue = value;
100+
while (auto *v = dyn_cast<ReferenceValue>(currentValue.get())) {
101+
// Detect a cycle.
102+
if (!visited.insert(v).second)
103+
return failure();
104+
currentValue = v->getValue();
105+
}
106+
return success(currentValue);
107+
}
108+
109+
private:
110+
EvaluatorValuePtr value;
111+
Type type;
52112
};
53113

54114
/// Values which can be directly representable by MLIR attributes.
55115
struct AttributeValue : EvaluatorValue {
56116
AttributeValue(Attribute attr)
57-
: EvaluatorValue(attr.getContext(), Kind::Attr), attr(attr) {}
117+
: EvaluatorValue(attr.getContext(), Kind::Attr), attr(attr) {
118+
markFullyEvaluated();
119+
}
120+
58121
Attribute getAttr() const { return attr; }
59122
template <typename AttrTy>
60123
AttrTy getAs() const {
@@ -64,20 +127,52 @@ struct AttributeValue : EvaluatorValue {
64127
return e->getKind() == Kind::Attr;
65128
}
66129

130+
// Finalize the value.
131+
LogicalResult finalizeImpl() { return success(); }
132+
133+
Type getType() const { return attr.cast<TypedAttr>().getType(); }
134+
67135
private:
68-
Attribute attr;
136+
Attribute attr = {};
69137
};
70138

139+
// This perform finalization to `value`.
140+
static inline LogicalResult finalizeEvaluatorValue(EvaluatorValuePtr &value) {
141+
if (failed(value->finalize()))
142+
return failure();
143+
if (auto *ref = llvm::dyn_cast<ReferenceValue>(value.get())) {
144+
auto v = ref->getStrippedValue();
145+
if (failed(v))
146+
return v;
147+
value = v.value();
148+
}
149+
return success();
150+
}
151+
71152
/// A List which contains variadic length of elements with the same type.
72153
struct ListValue : EvaluatorValue {
73154
ListValue(om::ListType type, SmallVector<EvaluatorValuePtr> elements)
74155
: EvaluatorValue(type.getContext(), Kind::List), type(type),
75-
elements(std::move(elements)) {}
156+
elements(std::move(elements)) {
157+
markFullyEvaluated();
158+
}
159+
160+
void setElements(SmallVector<EvaluatorValuePtr> newElements) {
161+
elements = std::move(newElements);
162+
markFullyEvaluated();
163+
}
164+
165+
// Finalize the value.
166+
LogicalResult finalizeImpl();
167+
168+
// Partially evaluated value.
169+
ListValue(om::ListType type)
170+
: EvaluatorValue(type.getContext(), Kind::List), type(type) {}
76171

77172
const auto &getElements() const { return elements; }
78173

79174
/// Return the type of the value, which is a ListType.
80-
om::ListType getType() const { return type; }
175+
om::ListType getListType() const { return type; }
81176

82177
/// Implement LLVM RTTI.
83178
static bool classof(const EvaluatorValue *e) {
@@ -93,12 +188,25 @@ struct ListValue : EvaluatorValue {
93188
struct MapValue : EvaluatorValue {
94189
MapValue(om::MapType type, DenseMap<Attribute, EvaluatorValuePtr> elements)
95190
: EvaluatorValue(type.getContext(), Kind::Map), type(type),
96-
elements(std::move(elements)) {}
191+
elements(std::move(elements)) {
192+
markFullyEvaluated();
193+
}
194+
195+
// Partially evaluated value.
196+
MapValue(om::MapType type)
197+
: EvaluatorValue(type.getContext(), Kind::Map), type(type) {}
97198

98199
const auto &getElements() const { return elements; }
200+
void setElements(DenseMap<Attribute, EvaluatorValuePtr> newElements) {
201+
elements = std::move(newElements);
202+
markFullyEvaluated();
203+
}
204+
205+
// Finalize the evaluator value.
206+
LogicalResult finalizeImpl();
99207

100208
/// Return the type of the value, which is a MapType.
101-
om::MapType getType() const { return type; }
209+
om::MapType getMapType() const { return type; }
102210

103211
/// Return an array of keys in the ascending order.
104212
ArrayAttr getKeys();
@@ -117,28 +225,48 @@ struct MapValue : EvaluatorValue {
117225
struct ObjectValue : EvaluatorValue {
118226
ObjectValue(om::ClassOp cls, ObjectFields fields)
119227
: EvaluatorValue(cls.getContext(), Kind::Object), cls(cls),
120-
fields(std::move(fields)) {}
228+
fields(std::move(fields)) {
229+
markFullyEvaluated();
230+
}
231+
232+
// Partially evaluated value.
233+
ObjectValue(om::ClassOp cls)
234+
: EvaluatorValue(cls.getContext(), Kind::Object), cls(cls) {}
235+
121236
om::ClassOp getClassOp() const { return cls; }
122237
const auto &getFields() const { return fields; }
123238

239+
void setFields(llvm::SmallDenseMap<StringAttr, EvaluatorValuePtr> newFields) {
240+
fields = std::move(newFields);
241+
markFullyEvaluated();
242+
}
243+
124244
/// Return the type of the value, which is a ClassType.
125-
om::ClassType getType() const {
245+
om::ClassType getObjectType() const {
126246
auto clsConst = const_cast<ClassOp &>(cls);
127247
return ClassType::get(clsConst.getContext(),
128248
FlatSymbolRefAttr::get(clsConst.getNameAttr()));
129249
}
130250

251+
Type getType() const { return getObjectType(); }
252+
131253
/// Implement LLVM RTTI.
132254
static bool classof(const EvaluatorValue *e) {
133255
return e->getKind() == Kind::Object;
134256
}
135257

136258
/// Get a field of the Object by name.
137259
FailureOr<EvaluatorValuePtr> getField(StringAttr field);
260+
FailureOr<EvaluatorValuePtr> getField(StringRef field) {
261+
return getField(StringAttr::get(getContext(), field));
262+
}
138263

139264
/// Get all the field names of the Object.
140265
ArrayAttr getFieldNames();
141266

267+
// Finalize the evaluator value.
268+
LogicalResult finalizeImpl();
269+
142270
private:
143271
om::ClassOp cls;
144272
llvm::SmallDenseMap<StringAttr, EvaluatorValuePtr> fields;
@@ -149,15 +277,33 @@ struct TupleValue : EvaluatorValue {
149277
using TupleElements = llvm::SmallVector<EvaluatorValuePtr>;
150278
TupleValue(TupleType type, TupleElements tupleElements)
151279
: EvaluatorValue(type.getContext(), Kind::Tuple), type(type),
152-
elements(std::move(tupleElements)) {}
280+
elements(std::move(tupleElements)) {
281+
markFullyEvaluated();
282+
}
283+
284+
// Partially evaluated value.
285+
TupleValue(TupleType type)
286+
: EvaluatorValue(type.getContext(), Kind::Tuple), type(type) {}
287+
288+
void setElements(TupleElements newElements) {
289+
elements = std::move(newElements);
290+
markFullyEvaluated();
291+
}
292+
293+
LogicalResult finalizeImpl() {
294+
for (auto &&value : elements)
295+
if (failed(finalizeEvaluatorValue(value)))
296+
return failure();
153297

298+
return success();
299+
}
154300
/// Implement LLVM RTTI.
155301
static bool classof(const EvaluatorValue *e) {
156302
return e->getKind() == Kind::Tuple;
157303
}
158304

159305
/// Return the type of the value, which is a TupleType.
160-
TupleType getType() const { return type; }
306+
TupleType getTupleType() const { return type; }
161307

162308
const TupleElements &getElements() const { return elements; }
163309

@@ -182,49 +328,81 @@ struct Evaluator {
182328
Evaluator(ModuleOp mod);
183329

184330
/// Instantiate an Object with its class name and actual parameters.
185-
FailureOr<std::shared_ptr<Object>>
331+
FailureOr<evaluator::EvaluatorValuePtr>
186332
instantiate(StringAttr className, ArrayRef<EvaluatorValuePtr> actualParams);
187333

188334
/// Get the Module this Evaluator is built from.
189335
mlir::ModuleOp getModule();
190336

337+
FailureOr<evaluator::EvaluatorValuePtr> getPartiallyEvaluatedValue(Type type);
338+
339+
using ActualParameters =
340+
SmallVectorImpl<std::shared_ptr<evaluator::EvaluatorValue>> *;
341+
342+
using ObjectKey = std::pair<Value, ActualParameters>;
343+
191344
private:
345+
bool isFullyEvaluated(Value value, ActualParameters key) {
346+
return isFullyEvaluated({value, key});
347+
}
348+
349+
bool isFullyEvaluated(ObjectKey key) {
350+
auto val = objects.lookup(key);
351+
return val && val->isFullyEvaluated();
352+
}
353+
354+
FailureOr<EvaluatorValuePtr> getOrCreateValue(Value value,
355+
ActualParameters actualParams);
356+
FailureOr<EvaluatorValuePtr>
357+
allocateObjectInstance(StringAttr clasName, ActualParameters actualParams);
358+
192359
/// Evaluate a Value in a Class body according to the small expression grammar
193360
/// described in the rationale document. The actual parameters are the values
194361
/// supplied at the current instantiation of the Class being evaluated.
195-
FailureOr<EvaluatorValuePtr>
196-
evaluateValue(Value value, ArrayRef<EvaluatorValuePtr> actualParams);
362+
FailureOr<EvaluatorValuePtr> evaluateValue(Value value,
363+
ActualParameters actualParams);
197364

198365
/// Evaluator dispatch functions for the small expression grammar.
199-
FailureOr<EvaluatorValuePtr>
200-
evaluateParameter(BlockArgument formalParam,
201-
ArrayRef<EvaluatorValuePtr> actualParams);
366+
FailureOr<EvaluatorValuePtr> evaluateParameter(BlockArgument formalParam,
367+
ActualParameters actualParams);
202368

369+
FailureOr<EvaluatorValuePtr> evaluateConstant(ConstantOp op,
370+
ActualParameters actualParams);
371+
/// Instantiate an Object with its class name and actual parameters.
203372
FailureOr<EvaluatorValuePtr>
204-
evaluateConstant(ConstantOp op, ArrayRef<EvaluatorValuePtr> actualParams);
205-
FailureOr<EvaluatorValuePtr>
206-
evaluateObjectInstance(ObjectOp op, ArrayRef<EvaluatorValuePtr> actualParams);
373+
evaluateObjectInstance(StringAttr className, ActualParameters actualParams,
374+
ObjectKey instanceObjectKey = {});
207375
FailureOr<EvaluatorValuePtr>
208-
evaluateObjectField(ObjectFieldOp op,
209-
ArrayRef<EvaluatorValuePtr> actualParams);
376+
evaluateObjectInstance(ObjectOp op, ActualParameters actualParams);
210377
FailureOr<EvaluatorValuePtr>
211-
evaluateListCreate(ListCreateOp op, ArrayRef<EvaluatorValuePtr> actualParams);
378+
evaluateObjectField(ObjectFieldOp op, ActualParameters actualParams);
212379
FailureOr<EvaluatorValuePtr>
213-
evaluateTupleCreate(TupleCreateOp op,
214-
ArrayRef<EvaluatorValuePtr> actualParams);
380+
evaluateListCreate(ListCreateOp op, ActualParameters actualParams);
215381
FailureOr<EvaluatorValuePtr>
216-
evaluateTupleGet(TupleGetOp op, ArrayRef<EvaluatorValuePtr> actualParams);
382+
evaluateTupleCreate(TupleCreateOp op, ActualParameters actualParams);
383+
FailureOr<EvaluatorValuePtr> evaluateTupleGet(TupleGetOp op,
384+
ActualParameters actualParams);
217385
FailureOr<evaluator::EvaluatorValuePtr>
218-
evaluateMapCreate(MapCreateOp op,
219-
ArrayRef<evaluator::EvaluatorValuePtr> actualParams);
386+
evaluateMapCreate(MapCreateOp op, ActualParameters actualParams);
387+
388+
FailureOr<ActualParameters>
389+
createParametersFromOperands(ValueRange range, ActualParameters actualParams);
220390

221391
/// The symbol table for the IR module the Evaluator was constructed with.
222392
/// Used to look up class definitions.
223393
SymbolTable symbolTable;
224394

225-
/// Object storage. Currently used for memoizing calls to
226-
/// evaluateObjectInstance. Further refinement is expected.
227-
DenseMap<Value, std::shared_ptr<evaluator::EvaluatorValue>> objects;
395+
/// This uniquely stores vectors that represent parameters.
396+
SmallVector<
397+
std::unique_ptr<SmallVector<std::shared_ptr<evaluator::EvaluatorValue>>>>
398+
actualParametersBuffers;
399+
400+
/// A worklist that tracks values which needs to be fully evaluated.
401+
std::queue<ObjectKey> worklist;
402+
403+
/// Evaluator value storage. Return an evaluator value for the given
404+
/// instantiation context (a pair of Value and parameters).
405+
DenseMap<ObjectKey, std::shared_ptr<evaluator::EvaluatorValue>> objects;
228406
};
229407

230408
/// Helper to enable printing objects in Diagnostics.

integration_test/Bindings/Python/dialects/om.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,7 @@
185185
if isinstance(data, om.Object):
186186
object_dict[data] = field_name
187187
assert len(object_dict) == 2
188+
189+
obj = evaluator.instantiate("Test", 41)
190+
# CHECK: 41
191+
print(obj.field)

lib/CAPI/Dialect/OM.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,7 @@ OMEvaluatorValue omEvaluatorInstantiate(OMEvaluator evaluator,
108108
cppActualParams.push_back(unwrap(actualParams[i]));
109109

110110
// Invoke the Evaluator to instantiate the Object.
111-
FailureOr<std::shared_ptr<evaluator::ObjectValue>> result =
112-
cppEvaluator->instantiate(cppClassName, cppActualParams);
111+
auto result = cppEvaluator->instantiate(cppClassName, cppActualParams);
113112

114113
// If instantiation failed, return a null Object. A Diagnostic will be emitted
115114
// in this case.

0 commit comments

Comments
 (0)