Add the ast layer to the jit

This commit is contained in:
Sameer Rahmani 2021-11-19 00:41:28 +00:00
parent 5a57f7c98b
commit b6ac7f3f5f
9 changed files with 173 additions and 91 deletions

View File

@ -22,7 +22,6 @@
#include "serene/diagnostics.h"
#include "serene/environment.h"
#include "serene/export.h"
#include "serene/jit.h"
#include "serene/jit/engine.h"
#include "serene/namespace.h"
#include "serene/passes.h"
@ -40,6 +39,8 @@
#include <memory>
#define DEFAULT_NS_NAME "serene.user"
namespace serene {
namespace reader {
@ -127,8 +128,9 @@ public:
// We need to create one empty namespace, so that the JIT can
// start it's operation.
auto ns = makeNamespace(*this, "serene.user", llvm::None);
auto ns = makeNamespace(*this, DEFAULT_NS_NAME, llvm::None);
insertNS(ns);
// TODO: Get the crash report path dynamically from the cli
// pm.enableCrashReproducerGeneration("/home/lxsameer/mlir.mlir");
@ -151,8 +153,12 @@ public:
};
static std::unique_ptr<SereneContext> make() {
auto ctx = std::make_unique<SereneContext>();
auto maybeJIT = serene::jit::makeSereneJIT(*ctx);
auto ctx = std::make_unique<SereneContext>();
auto *ns = ctx->getNS(DEFAULT_NS_NAME);
assert(ns != nullptr && "Default ns doesn't exit!");
auto maybeJIT = serene::jit::makeSereneJIT(*ns);
if (!maybeJIT) {
// TODO: Raise an error here

View File

@ -90,8 +90,8 @@ class SereneJIT {
/// and generate LLVM IR
orc::IRTransformLayer transformLayer;
/// The AST Layer reads and import the Serene Ast directly to the JIT
// SereneAstLayer astLayer;
// The AST Layer reads and import the Serene Ast directly to the JIT
AstLayer astLayer;
/// NS layer is responsible for adding namespace to the JIT by name.
/// It will import the entire namespace.
@ -129,9 +129,10 @@ class SereneJIT {
return std::move(tsm);
}
Namespace &currentNS;
public:
SereneJIT(serene::SereneContext &ctx,
std::unique_ptr<orc::ExecutionSession> es,
SereneJIT(Namespace &entryNS, std::unique_ptr<orc::ExecutionSession> es,
std::unique_ptr<orc::EPCIndirectionUtils> epciu,
orc::JITTargetMachineBuilder jtmb, llvm::DataLayout &&dl);
@ -145,6 +146,8 @@ public:
}
};
Namespace &getCurrentNS() { return currentNS; }
const llvm::DataLayout &getDataLayout() const { return dl; }
orc::JITDylib &getMainJITDylib() { return mainJD; }
@ -161,14 +164,15 @@ public:
llvm::Error addNS(llvm::StringRef nsname,
orc::ResourceTrackerSP rt = nullptr);
llvm::Error addAst(exprs::Ast &ast, orc::ResourceTrackerSP rt = nullptr);
llvm::Expected<llvm::JITEvaluatedSymbol> lookup(llvm::StringRef name) {
JIT_LOG("Looking up symbol: " + name);
return es->lookup({&mainJD}, mangler(name.str()));
}
};
llvm::Expected<std::unique_ptr<SereneJIT>>
makeSereneJIT(serene::SereneContext &ctx);
llvm::Expected<std::unique_ptr<SereneJIT>> makeSereneJIT(Namespace &ns);
}; // namespace jit
}; // namespace serene

View File

@ -19,6 +19,7 @@
#ifndef SERENE_JIT_LAYERS_H
#define SERENE_JIT_LAYERS_H
#include "serene/namespace.h"
#include "serene/reader/location.h"
#include "serene/utils.h"
@ -47,16 +48,14 @@ using Ast = std::vector<Node>;
namespace jit {
class SereneAstLayer;
class AstLayer;
/// This will compile the ast to llvm ir.
llvm::orc::ThreadSafeModule compileAst(serene::SereneContext &ctx,
exprs::Ast &ast);
llvm::orc::ThreadSafeModule compileAst(Namespace &ns, exprs::Ast &ast);
class SerenAstMaterializationUnit : public orc::MaterializationUnit {
class AstMaterializationUnit : public orc::MaterializationUnit {
public:
SerenAstMaterializationUnit(SereneContext &ctx, SereneAstLayer &l,
exprs::Ast &ast);
AstMaterializationUnit(Namespace &ns, AstLayer &l, exprs::Ast &ast);
llvm::StringRef getName() const override {
return "SereneAstMaterializationUnit";
@ -73,51 +72,42 @@ private:
llvm_unreachable("Serene functions are not overridable");
}
serene::SereneContext &ctx;
SereneAstLayer &astLayer;
Namespace &ns;
AstLayer &astLayer;
exprs::Ast &ast;
};
// class SereneAstLayer {
// SereneContext &ctx;
// orc::IRLayer &baseLayer;
// orc::MangleAndInterner &mangler;
class AstLayer {
orc::IRLayer &baseLayer;
orc::MangleAndInterner &mangler;
// const llvm::DataLayout &dl;
public:
AstLayer(orc::IRLayer &baseLayer, orc::MangleAndInterner &mangler)
: baseLayer(baseLayer), mangler(mangler){};
// public:
// SereneAstLayer(SereneContext &ctx, orc::IRLayer &baseLayer,
// orc::MangleAndInterner &mangler, const llvm::DataLayout &dl)
// : ctx(ctx), baseLayer(baseLayer), mangler(mangler), dl(dl){};
llvm::Error add(orc::ResourceTrackerSP &rt, Namespace &ns, exprs::Ast &ast) {
return rt->getJITDylib().define(
std::make_unique<AstMaterializationUnit>(ns, *this, ast), rt);
}
// llvm::Error add(orc::ResourceTrackerSP &rt, exprs::Ast &ast) {
// return rt->getJITDylib().define(
// std::make_unique<SerenAstMaterializationUnit>(ctx, *this, ast), rt);
// }
void emit(std::unique_ptr<orc::MaterializationResponsibility> mr,
Namespace &ns, exprs::Ast &e) {
// void emit(std::unique_ptr<orc::MaterializationResponsibility> mr,
// exprs::Ast &e) {
// baseLayer.emit(std::move(mr), compileAst(ctx, e));
// }
baseLayer.emit(std::move(mr), compileAst(ns, e));
}
// orc::SymbolFlagsMap getInterface(exprs::Ast &e) {
// orc::SymbolFlagsMap Symbols;
// Symbols[mangler(e.getName())] = llvm::JITSymbolFlags(
// llvm::JITSymbolFlags::Exported | llvm::JITSymbolFlags::Callable);
// return Symbols;
// }
// };
orc::SymbolFlagsMap getInterface(Namespace &ns, exprs::Ast &e);
};
/// NS Layer ==================================================================
class NSLayer;
/// This will compile the NS to llvm ir.
llvm::orc::ThreadSafeModule compileNS(serene::SereneContext &ctx,
serene::Namespace &ns);
llvm::orc::ThreadSafeModule compileNS(Namespace &ns);
class NSMaterializationUnit : public orc::MaterializationUnit {
public:
NSMaterializationUnit(SereneContext &ctx, NSLayer &l, serene::Namespace &ns);
NSMaterializationUnit(NSLayer &l, Namespace &ns);
llvm::StringRef getName() const override { return "NSMaterializationUnit"; }
@ -129,13 +119,12 @@ private:
const orc::SymbolStringPtr &sym) override {
UNUSED(jd);
UNUSED(sym);
UNUSED(ctx);
llvm_unreachable("Serene functions are not overridable");
// TODO: Check the ctx to see whether we need to remove the sym or not
}
serene::SereneContext &ctx;
NSLayer &nsLayer;
serene::Namespace &ns;
Namespace &ns;
};
/// NS Layer is responsible for adding namespaces to the JIT
@ -146,7 +135,7 @@ class NSLayer {
const llvm::DataLayout &dl;
public:
NSLayer(serene::SereneContext &ctx, orc::IRLayer &baseLayer,
NSLayer(SereneContext &ctx, orc::IRLayer &baseLayer,
orc::MangleAndInterner &mangler, const llvm::DataLayout &dl)
: ctx(ctx), baseLayer(baseLayer), mangler(mangler), dl(dl){};
@ -165,7 +154,7 @@ public:
// the data layout all the time
UNUSED(dl);
LAYER_LOG("Emit namespace");
baseLayer.emit(std::move(mr), compileNS(ctx, ns));
baseLayer.emit(std::move(mr), compileNS(ns));
}
orc::SymbolFlagsMap getInterface(serene::Namespace &ns);

View File

@ -124,6 +124,8 @@ public:
errors::OptionalErrors addTree(exprs::Ast &ast);
exprs::Ast &getTree();
const std::vector<llvm::StringRef> &getSymList() { return symbolList; };
/// Increase the function counter by one
uint nextFnCounter();
@ -132,12 +134,16 @@ public:
// TODO: Fix the return type and use a `llvm::Optional` instead
/// Generate and return a MLIR ModuleOp tha contains the IR of the namespace
/// with respect to the compilation phase
MaybeModuleOp generate();
MaybeModuleOp generate(unsigned offset = 0);
/// Compile the namespace to a llvm module. It will call the
/// `generate` method of the namespace to generate the IR.
MaybeModule compileToLLVM();
/// Compile the given namespace from the given \p offset of AST till the end
/// of the trees.
MaybeModule compileToLLVMFromOffset(unsigned offset);
/// Run all the passes specified in the context on the given MLIR ModuleOp.
mlir::LogicalResult runPasses(mlir::ModuleOp &m);

View File

@ -34,7 +34,7 @@ static void handleLazyCallThroughError() {
// TODO: terminate ?
}
SereneJIT::SereneJIT(serene::SereneContext &ctx,
SereneJIT::SereneJIT(Namespace &entryNS,
std::unique_ptr<orc::ExecutionSession> es,
std::unique_ptr<orc::EPCIndirectionUtils> epciu,
orc::JITTargetMachineBuilder jtmb, llvm::DataLayout &&dl)
@ -48,13 +48,10 @@ SereneJIT::SereneJIT(serene::SereneContext &ctx,
*this->es, objectLayer,
std::make_unique<orc::ConcurrentIRCompiler>(std::move(jtmb))),
transformLayer(*this->es, compileLayer, optimizeModule),
// TODO: Change compileOnDemandLayer to use an optimization layer
// as the parent
// compileOnDemandLayer(
// *this->es, compileLayer, this->epciu->getLazyCallThroughManager(),
// [this] { return this->epciu->createIndirectStubsManager(); }),
nsLayer(ctx, transformLayer, mangler, dl),
mainJD(this->es->createBareJITDylib(ctx.getCurrentNS().name)), ctx(ctx) {
astLayer(transformLayer, mangler),
nsLayer(entryNS.getContext(), transformLayer, mangler, dl),
mainJD(this->es->createBareJITDylib(entryNS.name)),
ctx(entryNS.getContext()), currentNS(entryNS) {
UNUSED(this->ctx);
mainJD.addGenerator(
cantFail(orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
@ -74,8 +71,15 @@ llvm::Error SereneJIT::addNS(llvm::StringRef nsname,
return nsLayer.add(rt, nsname);
};
llvm::Expected<std::unique_ptr<SereneJIT>>
makeSereneJIT(serene::SereneContext &ctx) {
llvm::Error SereneJIT::addAst(exprs::Ast &ast, orc::ResourceTrackerSP rt) {
if (!rt) {
rt = mainJD.getDefaultResourceTracker();
}
return astLayer.add(rt, getCurrentNS(), ast);
};
llvm::Expected<std::unique_ptr<SereneJIT>> makeSereneJIT(Namespace &ns) {
auto epc = orc::SelfExecutorProcessControl::Create();
if (!epc) {
return epc.takeError();
@ -102,7 +106,7 @@ makeSereneJIT(serene::SereneContext &ctx) {
return dl.takeError();
}
return std::make_unique<SereneJIT>(ctx, std::move(es), std::move(*epciu),
return std::make_unique<SereneJIT>(ns, std::move(es), std::move(*epciu),
std::move(jtmb), std::move(*dl));
};
} // namespace serene::jit

View File

@ -25,30 +25,79 @@
#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
#include <llvm/Support/Error.h> // for report_fatal_error
#include <algorithm>
namespace serene::jit {
// llvm::orc::ThreadSafeModule compileAst(serene::SereneContext &ctx,
// exprs::Ast &ast){
llvm::orc::ThreadSafeModule compileAst(Namespace &ns, exprs::Ast &ast) {
// };
assert(ns.getTree().size() < ast.size() && "Did you add the ast to the NS?");
// SerenAstMaterializationUnit::SerenAstMaterializationUnit(
// serene::SereneContext &ctx, SereneAstLayer &l, exprs::Ast &ast)
// : MaterializationUnit(l.getInterface(ast), nullptr), ctx(ctx),
// astLayer(l),
// ast(ast){};
LAYER_LOG("Compile in context of namespace: " + ns.name);
unsigned offset = ns.getTree().size() - ast.size();
// void SerenAstMaterializationUnit::materialize(
// std::unique_ptr<orc::MaterializationResponsibility> r) {
// astLayer.emit(std::move(r), ast);
// }
auto maybeModule = ns.compileToLLVMFromOffset(offset);
if (!maybeModule) {
// TODO: Handle failure
llvm::report_fatal_error("Couldn't compile lazily JIT'd function");
}
return std::move(maybeModule.getValue());
};
AstMaterializationUnit::AstMaterializationUnit(Namespace &ns, AstLayer &l,
exprs::Ast &ast)
: MaterializationUnit(l.getInterface(ns, ast), nullptr), ns(ns),
astLayer(l), ast(ast){};
void AstMaterializationUnit::materialize(
std::unique_ptr<orc::MaterializationResponsibility> r) {
astLayer.emit(std::move(r), ns, ast);
}
orc::SymbolFlagsMap AstLayer::getInterface(Namespace &ns, exprs::Ast &e) {
orc::SymbolFlagsMap Symbols;
auto symList = ns.getSymList();
unsigned index = symList.size();
// This probably will change symList
auto err = ns.addTree(e);
if (err) {
// TODO: Fix this by a call to diag engine or return the err
llvm::outs() << "Fixme: semantic err\n";
return Symbols;
}
auto &env = ns.getRootEnv();
auto populateTableFn = [&env, this, &Symbols](auto name) {
auto flags = llvm::JITSymbolFlags::Exported;
auto maybeExpr = env.lookup(name.str());
if (!maybeExpr) {
LAYER_LOG("Skiping '" + name + "' symbol");
return;
}
auto expr = maybeExpr.getValue();
if (expr->getType() == exprs::ExprType::Fn) {
flags = flags | llvm::JITSymbolFlags::Callable;
}
auto mangledSym = this->mangler(name);
LAYER_LOG("Mangle symbol for: " + name + " = " << mangledSym);
Symbols[mangledSym] = llvm::JITSymbolFlags(flags);
};
std::for_each(symList.begin() + index, symList.end(), populateTableFn);
return Symbols;
}
/// NS Layer ==================================================================
llvm::orc::ThreadSafeModule compileNS(serene::SereneContext &ctx,
serene::Namespace &ns) {
UNUSED(ctx);
llvm::orc::ThreadSafeModule compileNS(Namespace &ns) {
LAYER_LOG("Compile namespace: " + ns.name);
auto maybeModule = ns.compileToLLVM();
@ -61,10 +110,8 @@ llvm::orc::ThreadSafeModule compileNS(serene::SereneContext &ctx,
return std::move(maybeModule.getValue());
};
NSMaterializationUnit::NSMaterializationUnit(SereneContext &ctx, NSLayer &l,
serene::Namespace &ns)
: MaterializationUnit(l.getInterface(ns), nullptr), ctx(ctx), nsLayer(l),
ns(ns){};
NSMaterializationUnit::NSMaterializationUnit(NSLayer &l, Namespace &ns)
: MaterializationUnit(l.getInterface(ns), nullptr), nsLayer(l), ns(ns){};
void NSMaterializationUnit::materialize(
std::unique_ptr<orc::MaterializationResponsibility> r) {
@ -89,7 +136,7 @@ llvm::Error NSLayer::add(orc::ResourceTrackerSP &rt, llvm::StringRef nsname,
LAYER_LOG("Add the materialize unit for: " + nsname);
return rt->getJITDylib().define(
std::make_unique<NSMaterializationUnit>(ctx, *this, *ns), rt);
std::make_unique<NSMaterializationUnit>(*this, *ns), rt);
}
orc::SymbolFlagsMap NSLayer::getInterface(serene::Namespace &ns) {
@ -104,8 +151,8 @@ orc::SymbolFlagsMap NSLayer::getInterface(serene::Namespace &ns) {
flags = flags | llvm::JITSymbolFlags::Callable;
}
auto mangledSym = mangler(k.getFirst());
LAYER_LOG("Mangle symbol for: " + k.getFirst() + " = " << mangledSym);
auto mangledSym = mangler(name);
LAYER_LOG("Mangle symbol for: " + name + " = " << mangledSym);
Symbols[mangledSym] = llvm::JITSymbolFlags(flags);
}

View File

@ -117,18 +117,21 @@ uint Namespace::nextFnCounter() { return fn_counter++; };
SereneContext &Namespace::getContext() { return this->ctx; };
MaybeModuleOp Namespace::generate() {
MaybeModuleOp Namespace::generate(unsigned offset) {
mlir::OpBuilder builder(&ctx.mlirContext);
// TODO: Fix the unknown location by pointing to the `ns` form
auto module = mlir::ModuleOp::create(builder.getUnknownLoc(),
llvm::Optional<llvm::StringRef>(name));
auto treeSize = getTree().size();
// Walk the AST and call the `generateIR` function of each node.
// Since nodes will have access to the a reference of the
// namespace they can use the builder and keep adding more
// operations to the module via the builder
for (auto &x : getTree()) {
x->generateIR(*this, module);
for (unsigned i = offset; i < treeSize; ++i) {
auto &node = getTree()[i];
node->generateIR(*this, module);
}
if (mlir::failed(mlir::verify(module))) {
@ -182,6 +185,22 @@ MaybeModule Namespace::compileToLLVM() {
return llvm::None;
};
MaybeModule Namespace::compileToLLVMFromOffset(unsigned offset) {
auto maybeModule = generate(offset);
if (!maybeModule) {
NAMESPACE_LOG("IR generation failed for '" << name << "'");
return llvm::None;
}
if (ctx.getTargetPhase() >= CompilationPhase::IR) {
mlir::ModuleOp module = maybeModule.getValue().get();
return ::serene::slir::compileToLLVMIR(ctx, module);
}
return llvm::None;
};
Namespace::~Namespace(){};
NSPtr makeNamespace(SereneContext &ctx, llvm::StringRef name,

View File

@ -127,6 +127,14 @@ SERENE_EXPORT exprs::MaybeNode eval(SereneContext &ctx, exprs::Ast &input) {
auto *f = (int (*)())(intptr_t)sym.getAddress();
f();
err = ctx.jit->addAst(input);
if (err) {
llvm::errs() << err;
auto e = errors::makeErrorTree(loc, errors::NSLoadError);
return exprs::makeErrorNode(loc, errors::NSLoadError);
}
return exprs::make<exprs::Number>(loc, "4", false, false);
};

View File

@ -52,8 +52,7 @@ int main(int argc, char *argv[]) {
llvm::outs() << banner << art;
auto ctx = makeSereneContext();
auto userNS = makeNamespace(*ctx, "user", llvm::None);
auto ctx = makeSereneContext();
applySereneCLOptions(*ctx);
@ -74,7 +73,7 @@ int main(int argc, char *argv[]) {
// Read line
std::string line;
std::string result;
std::string prompt = ctx->getCurrentNS().name + "> ";
std::string prompt = ctx->jit->getCurrentNS().name + "> ";
auto quit = linenoise::Readline(prompt.c_str(), line);