Skip to content

Commit

Permalink
Basic generic function invocation started.
Browse files Browse the repository at this point in the history
  • Loading branch information
asoffer committed Oct 13, 2023
1 parent 44ce6b3 commit a429f79
Show file tree
Hide file tree
Showing 14 changed files with 118 additions and 74 deletions.
6 changes: 3 additions & 3 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

http_archive(
name = "asoffer_nth",
urls = ["https://github.com/asoffer/nth/archive/1dece678fc9a4a4b4e1ea82dc71370482c2c5fcb.zip"],
strip_prefix = "nth-1dece678fc9a4a4b4e1ea82dc71370482c2c5fcb",
sha256 = "4a698880fab2a1533ccf8c031bf47a07e0da14e04e6f8782c6a9e306de04efaa",
urls = ["https://github.com/asoffer/nth/archive/c1b12edeb31d8732e76811e081d7cba3afdadcf4.zip"],
strip_prefix = "nth-c1b12edeb31d8732e76811e081d7cba3afdadcf4",
sha256 = "1673c58d3a553f76eed0816681f21ac88e5b09af98c6f042ef4cb6d155860e97",
)

http_archive(
Expand Down
4 changes: 4 additions & 0 deletions common/resources.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ struct Resources {
return identifiers.index(identifiers.insert(s).first);
}

std::string_view Identifier(size_t index) {
return identifiers.from_index(index);
}

// Values of string literals used in the program.
nth::flyweight_set<std::string> strings;

Expand Down
2 changes: 1 addition & 1 deletion ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ cc_library(
"@asoffer_jasmin//jasmin:execute",
"@asoffer_nth//nth/debug",
"@asoffer_nth//nth/debug/log",
"@asoffer_nth//nth/container:interval_set",
"@asoffer_nth//nth/container:interval_map",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:btree",
],
Expand Down
22 changes: 9 additions & 13 deletions ir/builtin_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ nth::NoDestructor<IrFunction> Function([] {
}());

nth::NoDestructor<IrFunction> Foreign([] {
IrFunction f(2, 1);
IrFunction f(1, 1);
f.append<jasmin::Return>();
return f;
}());

Expand All @@ -51,21 +52,16 @@ Module BuiltinModule(GlobalFunctionRegistry& registry) {
registry.Register(FunctionId(ModuleId::Builtin(), LocalFunctionId(next_id++)),
&*PrintFn);

m.Insert(resources.IdentifierIndex("foreign"),
{.qualified_type = type::QualifiedType::Constant(
type::GenericFunction(&*ForeignType)),
.value = {}});
m.Insert(
resources.IdentifierIndex("foreign"),
{.qualified_type = type::QualifiedType::Constant(
type::GenericFunction(type::Evaluation::CompileTime, &*ForeignType)),
.value = {&*Foreign}});
registry.Register(FunctionId(ModuleId::Builtin(), LocalFunctionId(next_id++)),
&*ForeignType);
registry.Register(FunctionId(ModuleId::Builtin(), LocalFunctionId(next_id++)),
&*Foreign);

m.Insert(
resources.IdentifierIndex("b2b"),
{.qualified_type = type::QualifiedType::Constant(type::Type_),
.value = {jasmin::Value(type::Type(type::Function(
type::Parameters(std::vector<type::ParametersType::Parameter>{
{.name = resources.IdentifierIndex(""), .type = type::Bool}}),
{type::Bool})))}});

m.Insert(resources.IdentifierIndex("function"),
{.qualified_type =
type::QualifiedType::Constant(type::Pattern(type::Type_)),
Expand Down
3 changes: 2 additions & 1 deletion ir/deserialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ bool Deserializer::Deserialize(ModuleProto const& proto, Module& module) {
module.add_function(function.parameters(), function.returns());
}

if (not DeserializeFunction(proto, proto.initializer(), module.initializer())) {
if (not DeserializeFunction(proto, proto.initializer(),
module.initializer())) {
return false;
}

Expand Down
35 changes: 19 additions & 16 deletions ir/emit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@ void HandleParseTreeNodeTypeLiteral(ParseTree::Node::Index index,
EmitContext& context) {
auto node = context.Node(index);
switch (node.token.kind()) {
case Token::Kind::Bool:
context.function_stack.back()->append<jasmin::Push>(type::Bool);
break;
#define IC_XMACRO_TOKEN_KIND_BUILTIN_TYPE(kind, symbol, spelling) \
case Token::Kind::kind: \
context.function_stack.back()->append<jasmin::Push>(type::symbol); \
break;
#include "lexer/token_kind.xmacro.h"
default: NTH_UNREACHABLE();
}
}
Expand Down Expand Up @@ -113,19 +115,13 @@ void HandleParseTreeNodeExpressionGroup(ParseTree::Node::Index, EmitContext&) {

void HandleParseTreeNodeMemberExpression(ParseTree::Node::Index index,
EmitContext& context) {
// TODO: Once we have interval_map this will be easier.
jasmin::ValueStack const* vs = nullptr;
for (auto const& [range, value_stack] : context.constants) {
if (range.upper_bound() == index) {
vs = &value_stack;
break;
}
}
NTH_REQUIRE((v.harden), vs != nullptr);
auto const* mapped_range = context.constants.mapped_range(index);
NTH_REQUIRE((v.harden), mapped_range != nullptr);
context.function_stack.back()->append<jasmin::Drop>(1);

ModuleId module_id;
bool successfully_deserialized =
IcarusDeserializeValue(std::span(vs->begin(), vs->end()), module_id);
IcarusDeserializeValue(mapped_range->second.value_span(), module_id);
NTH_REQUIRE((v.harden), successfully_deserialized);

auto symbol = context.module(module_id).Lookup(
Expand Down Expand Up @@ -167,6 +163,7 @@ void EmitNonConstant(nth::interval<ParseTree::Node::Index> node_range,

void EmitContext::Push(jasmin::Value v, type::Type t) {
switch (t.kind()) {
case type::Type::Kind::GenericFunction:
case type::Type::Kind::Function: {
function_stack.back()->append<PushFunction>(v);
} break;
Expand All @@ -178,11 +175,13 @@ void EmitContext::Push(jasmin::Value v, type::Type t) {

void EmitIr(nth::interval<ParseTree::Node::Index> node_range, EmitContext& context) {
ParseTree::Node::Index start = node_range.lower_bound();
for (auto const& [range, value_stack] : context.constants) {
for (auto const& [range, constant] : context.constants.mapped_intervals()) {
if (range.lower_bound() < start) { continue; }
EmitNonConstant(nth::interval(start, range.lower_bound()), context);
// TODO: This type is wrong.
for (jasmin::Value const& v : value_stack) { context.Push(v, type::Bool); }
for (jasmin::Value const& v : constant.value_span()) {
context.Push(v, type::Bool);
}
start = range.upper_bound();
}
EmitNonConstant(nth::interval(start, node_range.upper_bound()), context);
Expand All @@ -191,11 +190,15 @@ void EmitIr(nth::interval<ParseTree::Node::Index> node_range, EmitContext& conte

void EmitContext::Evaluate(nth::interval<ParseTree::Node::Index> subtree,
jasmin::ValueStack& value_stack) {
jasmin::ValueStack vs;
IrFunction f(0, 1);
function_stack.push_back(&f);
EmitIr(subtree, *this);
f.append<jasmin::Return>();
jasmin::Execute(f, value_stack);
jasmin::Execute(f, vs);
for (jasmin::Value v : vs) { value_stack.push(v); }
constants.insert_or_assign(
subtree, ComputedConstant(subtree.upper_bound() - 1, std::move(vs)));
function_stack.pop_back();
}

Expand Down
42 changes: 26 additions & 16 deletions ir/emit.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "ir/module_id.h"
#include "jasmin/value_stack.h"
#include "nth/base/attributes.h"
#include "nth/container/interval_map.h"
#include "parser/parse_tree.h"
#include "type/type.h"

Expand Down Expand Up @@ -47,25 +48,34 @@ struct EmitContext {
std::vector<IrFunction*> function_stack;
absl::flat_hash_map<ParseTree::Node::Index, size_t> rotation_count;

struct Compare {
bool operator()(nth::interval<ParseTree::Node::Index> const& lhs,
nth::interval<ParseTree::Node::Index> const& rhs) const {
auto const& [lhs_l, lhs_u] = lhs;
auto const& [rhs_l, rhs_u] = rhs;
if (lhs_l < rhs_l) { return true; }
if (lhs_l > rhs_l) { return false; }
return lhs_u < rhs_u;
struct ComputedConstant {
explicit ComputedConstant(ParseTree::Node::Index index,
jasmin::ValueStack value)
: index_(index), value_(std::move(value)) {}

friend bool operator==(ComputedConstant const& lhs,
ComputedConstant const& rhs) {
return lhs.index_ == rhs.index_;
}

friend bool operator!=(ComputedConstant const& lhs,
ComputedConstant const& rhs) {
return not(lhs == rhs);
}

std::span<jasmin::Value const> value_span() const {
return std::span<jasmin::Value const>(value_.begin(), value_.end());
}

private:
ParseTree::Node::Index index_;
jasmin::ValueStack value_;
};

// Indices covering subtree roots which were required to be constant evaluated
// in order to type-check their parent, mapped to their corresponding constant
// value.
//
// TODO: This should really be it's own interval map type.
absl::btree_map<nth::interval<ParseTree::Node::Index>, jasmin::ValueStack,
Compare>
constants;
// Maps node indices to the constant value associated with the computation for
// the largest subtree containing it whose constant value has been computed
// thus far.
nth::interval_map<ParseTree::Node::Index, ComputedConstant> constants;
DependentModules const& modules;
};

Expand Down
25 changes: 14 additions & 11 deletions ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ void HandleParseTreeNodeExpressionPrecedenceGroup(
case Token::Kind::MinusGreater: {
auto node = context.Node(index);
if (node.child_count != 2) {
for (auto const& c : context.Children(index)) { NTH_LOG("{}") <<= {c}; }
NTH_LOG("{}")<<={node.child_count};
NTH_REQUIRE(node.child_count != -1);
diag.Consume({
diag::Header(diag::MessageKind::Error),
Expand Down Expand Up @@ -177,7 +175,6 @@ void HandleParseTreeNodeMemberExpression(ParseTree::Node::Index index,
void HandleParseTreeNodeCallExpression(ParseTree::Node::Index index,
IrContext& context,
diag::DiagnosticConsumer& diag) {
auto& argument_width_count = context.emit.rotation_count[index];
auto node = context.Node(index);
auto invocable_type =
context.type_stack[context.type_stack.size() - node.child_count];
Expand All @@ -187,6 +184,7 @@ void HandleParseTreeNodeCallExpression(ParseTree::Node::Index index,
// TODO: Properly implement function call type-checking.
if (parameters.size() == node.child_count - 1) {
auto type_iter = context.type_stack.rbegin();
auto& argument_width_count = context.emit.rotation_count[index];
for (size_t i = 0; i < parameters.size(); ++i) {
argument_width_count += type::JasminSize(type_iter->type());
++type_iter;
Expand All @@ -207,15 +205,20 @@ void HandleParseTreeNodeCallExpression(ParseTree::Node::Index index,
context.Node(*iter).kind !=
ParseTree::Node::Kind::InvocationArgumentStart;
++iter) {
context.emit.Evaluate(context.emit.tree.subtree_range(*iter),
value_stack);
nth::interval range = context.emit.tree.subtree_range(*iter);
context.emit.Evaluate(range, value_stack);
}
auto g = invocable_type.type().AsGenericFunction();
jasmin::Execute(*static_cast<IrFunction const*>(g.data()), value_stack);
auto t = value_stack.pop<type::Type>();
context.type_stack.push_back(type::QualifiedType::Constant(t));
if (g.evaluation() == type::Evaluation::CompileTime) {
jasmin::ValueStack value_stack;
value_stack.push(t);
context.emit.constants.insert_or_assign(
context.emit.tree.subtree_range(index),
EmitContext::ComputedConstant(index, std::move(value_stack)));
}

jasmin::Execute(*static_cast<IrFunction const*>(
invocable_type.type().AsGenericFunction().data()),
value_stack);
context.type_stack.push_back(
type::QualifiedType::Constant(value_stack.pop<type::Type>()));
} else {
NTH_UNIMPLEMENTED("{}") <<= {node};
}
Expand Down
6 changes: 2 additions & 4 deletions ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,11 @@ struct IrContext {
template <typename T>
std::optional<T> EvaluateAs(ParseTree::Node::Index subtree_root_index) {
T result;
nth::interval range = emit.tree.subtree_range(subtree_root_index);
nth::interval range = emit.tree.subtree_range(subtree_root_index);
jasmin::ValueStack value_stack;
emit.Evaluate(range, value_stack);
auto [iter, inserted] = emit.constants.try_emplace(range, std::move(value_stack));
NTH_REQUIRE((v.harden), inserted);
if (IcarusDeserializeValue(
std::span(iter->second.begin(), iter->second.end()), result)) {
std::span(value_stack.begin(), value_stack.end()), result)) {
return result;
} else {
return std::nullopt;
Expand Down
4 changes: 3 additions & 1 deletion ir/module.proto
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package ic;
message InstructionProto {
enum OpCode {
// Note: This needs to be kept in sync with the `jasmin::InstructionSet`
// defined in "ir/module.h"
// defined in "ir/function.h"
CALL = 0;
JUMP = 1;
JUMP_IF = 2;
Expand Down Expand Up @@ -47,4 +47,6 @@ message ModuleProto {
repeated ForeignFunctionProto foreign_functions = 4;

type.TypeSystem type_system = 5;

map<uint32, string> identifiers = 6;
}
2 changes: 2 additions & 0 deletions toolchain/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,7 @@ cc_binary(
"@asoffer_nth//nth/io:file",
"@asoffer_nth//nth/io:file_path",
"@asoffer_nth//nth/process:exit_code",
"@com_google_absl//absl/debugging:failure_signal_handler",
"@com_google_absl//absl/debugging:symbolize",
],
)
10 changes: 9 additions & 1 deletion toolchain/run_bytecode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include <optional>
#include <string>

#include "absl/debugging/failure_signal_handler.h"
#include "absl/debugging/symbolize.h"
#include "common/string.h"
#include "diagnostics/consumer/streaming.h"
#include "diagnostics/message.h"
Expand All @@ -23,6 +25,10 @@ namespace ic {
namespace {

nth::exit_code Run(nth::FlagValueSet flags, std::span<std::string_view const>) {
absl::InitializeSymbolizer("");
absl::FailureSignalHandlerOptions opts;
absl::InstallFailureSignalHandler(opts);

auto const& input = flags.get<nth::file_path>("input");
std::ifstream in(input.path());

Expand Down Expand Up @@ -54,14 +60,16 @@ nth::exit_code Run(nth::FlagValueSet flags, std::span<std::string_view const>) {
}

if (not proto.ParseFromIstream(&in) or not d.Deserialize(proto, module)) {
NTH_LOG((v.debug), "{}") <<= {proto.DebugString()};
consumer.Consume({
diag::Header(diag::MessageKind::Error),
diag::Text(
InterpolateString<"Failed to parse the moudle content from {}.">(
InterpolateString<"Failed to parse the module content from {}.">(
input)),
});
return nth::exit_code::generic_error;
}
NTH_LOG((v.debug), "{}") <<= {proto.DebugString()};

jasmin::ValueStack value_stack;
jasmin::Execute(module.initializer(), value_stack);
Expand Down
13 changes: 9 additions & 4 deletions type/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ nth::NoDestructor<nth::flyweight_set<Type>> slice_element_types;
nth::NoDestructor<nth::flyweight_set<Type>> pointee_types;
nth::NoDestructor<nth::flyweight_set<Type>> buffer_pointee_types;
nth::NoDestructor<nth::flyweight_set<Type>> pattern_types;
nth::NoDestructor<nth::flyweight_set<void const*>> generic_function_types;
nth::NoDestructor<nth::flyweight_set<std::pair<void const*, Evaluation>>>
generic_function_types;

} // namespace

Expand Down Expand Up @@ -72,9 +73,9 @@ PatternType Pattern(Type t) {
return PatternType(pattern_types->index(pattern_types->insert(t).first));
}

GenericFunctionType GenericFunction(void const* fn) {
GenericFunctionType GenericFunction(Evaluation e, void const* fn) {
return GenericFunctionType(
generic_function_types->index(generic_function_types->insert(fn).first));
generic_function_types->index(generic_function_types->insert(std::pair(fn, e)).first));
}

Type SliceType::element_type() const {
Expand Down Expand Up @@ -105,7 +106,11 @@ std::vector<Type> const& FunctionType::returns() const {
}

void const* GenericFunctionType::data() const {
return generic_function_types->from_index(BasicType::data());
return generic_function_types->from_index(BasicType::data()).first;
}

Evaluation GenericFunctionType::evaluation() const {
return generic_function_types->from_index(BasicType::data()).second;
}

size_t JasminSize(Type t) {
Expand Down
Loading

0 comments on commit a429f79

Please sign in to comment.