Add the Halley JIT and migrate to it

This commit is contained in:
Sameer Rahmani 2021-12-30 13:52:33 +00:00
parent 123a3e8d4f
commit c41c91b335
11 changed files with 336 additions and 251 deletions

View File

@ -22,7 +22,7 @@
#include "serene/diagnostics.h" #include "serene/diagnostics.h"
#include "serene/environment.h" #include "serene/environment.h"
#include "serene/export.h" #include "serene/export.h"
#include "serene/jit/engine.h" #include "serene/jit/halley.h"
#include "serene/namespace.h" #include "serene/namespace.h"
#include "serene/passes.h" #include "serene/passes.h"
#include "serene/slir/dialect.h" #include "serene/slir/dialect.h"
@ -31,6 +31,7 @@
#include <llvm/ADT/None.h> #include <llvm/ADT/None.h>
#include <llvm/ADT/Optional.h> #include <llvm/ADT/Optional.h>
#include <llvm/ADT/StringRef.h> #include <llvm/ADT/StringRef.h>
#include <llvm/ADT/Triple.h>
#include <llvm/IR/LLVMContext.h> #include <llvm/IR/LLVMContext.h>
#include <llvm/Support/Host.h> #include <llvm/Support/Host.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h> #include <mlir/Dialect/StandardOps/IR/Ops.h>
@ -68,7 +69,10 @@ enum class CompilationPhase {
class SERENE_EXPORT SereneContext { class SERENE_EXPORT SereneContext {
struct Options { struct Options {
/// Whether to use colors for the output or not /// Whether to use colors for the output or not
bool withColors = true; bool withColors = true;
bool JITenableObjectCache = true;
bool JITenableGDBNotificationListener = true;
bool JITenablePerfNotificationListener = true;
Options() = default; Options() = default;
}; };
@ -89,7 +93,7 @@ public:
std::unique_ptr<DiagnosticEngine> diagEngine; std::unique_ptr<DiagnosticEngine> diagEngine;
std::unique_ptr<serene::jit::SereneJIT> jit; std::unique_ptr<serene::jit::Halley> jit;
/// The source manager is responsible for loading namespaces and practically /// The source manager is responsible for loading namespaces and practically
/// managing the source code in form of memory buffers. /// managing the source code in form of memory buffers.
@ -120,6 +124,11 @@ public:
/// return a pointer to it or a `nullptr` in it doesn't exist. /// return a pointer to it or a `nullptr` in it doesn't exist.
Namespace *getNS(llvm::StringRef ns_name); Namespace *getNS(llvm::StringRef ns_name);
/// Lookup and return a shared pointer to the given \p ns_name. This
/// method should be used only if you need to own the namespace as well
/// and want to keep it long term (like the JIT).
NSPtr getSharedPtrToNS(llvm::StringRef ns_name);
SereneContext() SereneContext()
: pm(&mlirContext), diagEngine(makeDiagnosticEngine(*this)), : pm(&mlirContext), diagEngine(makeDiagnosticEngine(*this)),
targetPhase(CompilationPhase::NoOptimization) { targetPhase(CompilationPhase::NoOptimization) {
@ -158,7 +167,7 @@ public:
assert(ns != nullptr && "Default ns doesn't exit!"); assert(ns != nullptr && "Default ns doesn't exit!");
auto maybeJIT = serene::jit::makeSereneJIT(*ns); auto maybeJIT = serene::jit::makeHalleyJIT(*ctx);
if (!maybeJIT) { if (!maybeJIT) {
// TODO: Raise an error here // TODO: Raise an error here
@ -170,6 +179,8 @@ public:
return ctx; return ctx;
}; };
llvm::Triple getTargetTriple() const { return llvm::Triple(targetTriple); };
private: private:
CompilationPhase targetPhase; CompilationPhase targetPhase;

View File

@ -50,6 +50,7 @@ enum ErrID {
E0011, E0011,
E0012, E0012,
E0013, E0013,
E0014,
}; };
struct ErrorVariant { struct ErrorVariant {
@ -106,6 +107,8 @@ static ErrorVariant
static ErrorVariant static ErrorVariant
InvalidCharacterForSymbol(E0013, "Invalid character for a symbol", ""); InvalidCharacterForSymbol(E0013, "Invalid character for a symbol", "");
static ErrorVariant CompilationError(E0014, "Compilation error!", "");
static std::map<ErrID, ErrorVariant *> ErrDesc = { static std::map<ErrID, ErrorVariant *> ErrDesc = {
{E0000, &UnknownError}, {E0001, &DefExpectSymbol}, {E0000, &UnknownError}, {E0001, &DefExpectSymbol},
{E0002, &DefWrongNumberOfArgs}, {E0003, &FnNoArgsList}, {E0002, &DefWrongNumberOfArgs}, {E0003, &FnNoArgsList},
@ -113,7 +116,8 @@ static std::map<ErrID, ErrorVariant *> ErrDesc = {
{E0006, &DontKnowHowToCallNode}, {E0007, &PassFailureError}, {E0006, &DontKnowHowToCallNode}, {E0007, &PassFailureError},
{E0008, &NSLoadError}, {E0009, &NSAddToSMError}, {E0008, &NSLoadError}, {E0009, &NSAddToSMError},
{E0010, &EOFWhileScaningAList}, {E0011, &InvalidDigitForNumber}, {E0010, &EOFWhileScaningAList}, {E0011, &InvalidDigitForNumber},
{E0012, &TwoFloatPoints}, {E0013, &InvalidCharacterForSymbol}}; {E0012, &TwoFloatPoints}, {E0013, &InvalidCharacterForSymbol},
{E0014, &CompilationError}};
} // namespace errors } // namespace errors
} // namespace serene } // namespace serene

View File

@ -18,16 +18,14 @@
/** /**
* Commentary: * Commentary:
* The code is based on the MLIR's JIT and named after Edmond Halley.
*/ */
#ifndef SERENE_JIT_H #ifndef SERENE_JIT_HALLEY_H
#define SERENE_JIT_H #define SERENE_JIT_HALLEY_H
#include "serene/errors.h" #include "serene/errors.h"
#include "serene/export.h" #include "serene/export.h"
#include "serene/exprs/expression.h"
#include "serene/namespace.h"
#include "serene/slir/generatable.h"
#include "serene/utils.h" #include "serene/utils.h"
#include <llvm/ADT/StringRef.h> #include <llvm/ADT/StringRef.h>
@ -40,13 +38,19 @@
#include <memory> #include <memory>
#define JIT2_LOG(...) \ #define HALLEY_LOG(...) \
DEBUG_WITH_TYPE("JIT", llvm::dbgs() << "[JIT]: " << __VA_ARGS__ << "\n"); DEBUG_WITH_TYPE("halley", llvm::dbgs() \
<< "[HALLEY]: " << __VA_ARGS__ << "\n");
namespace serene { namespace serene {
class SERENE_EXPORT JIT;
using MaybeJIT = llvm::Optional<std::unique_ptr<JIT>>; class SereneContext;
class Namespace;
namespace jit {
class Halley;
using MaybeJIT = llvm::Expected<std::unique_ptr<Halley>>;
/// A simple object cache following Lang's LLJITWithObjectCache example and /// A simple object cache following Lang's LLJITWithObjectCache example and
/// MLIR's SimpelObjectCache. /// MLIR's SimpelObjectCache.
@ -67,12 +71,8 @@ private:
llvm::StringMap<std::unique_ptr<llvm::MemoryBuffer>> cachedObjects; llvm::StringMap<std::unique_ptr<llvm::MemoryBuffer>> cachedObjects;
}; };
class JIT { class SERENE_EXPORT Halley {
// TODO: Should the JIT own the context ???
Namespace &ns;
std::unique_ptr<llvm::orc::LLJIT> engine; std::unique_ptr<llvm::orc::LLJIT> engine;
std::unique_ptr<ObjectCache> cache; std::unique_ptr<ObjectCache> cache;
/// GDB notification listener. /// GDB notification listener.
@ -81,16 +81,19 @@ class JIT {
/// Perf notification listener. /// Perf notification listener.
llvm::JITEventListener *perfListener; llvm::JITEventListener *perfListener;
public: llvm::orc::JITTargetMachineBuilder jtmb;
JIT(Namespace &ns, bool enableObjectCache = true, llvm::DataLayout &dl;
bool enableGDBNotificationListener = true, SereneContext &ctx;
bool enablePerfNotificationListener = true);
static MaybeJIT std::shared_ptr<Namespace> activeNS;
make(Namespace &ns, mlir::ArrayRef<llvm::StringRef> sharedLibPaths = {},
mlir::Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel = llvm::None, public:
bool enableObjectCache = true, bool enableGDBNotificationListener = true, Halley(serene::SereneContext &ctx, llvm::orc::JITTargetMachineBuilder &&jtmb,
bool enablePerfNotificationListener = true); llvm::DataLayout &&dl);
// TODO: Read the sharedLibPaths via context
static MaybeJIT make(serene::SereneContext &ctx,
llvm::orc::JITTargetMachineBuilder &&jtmb);
/// Looks up a packed-argument function with the given name and returns a /// Looks up a packed-argument function with the given name and returns a
/// pointer to it. Propagates errors in case of failure. /// pointer to it. Propagates errors in case of failure.
@ -165,9 +168,17 @@ public:
llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)> llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
symbolMap); symbolMap);
std::unique_ptr<exprs::Expression> eval(SereneContext &ctx, llvm::Optional<errors::ErrorTree> addNS(Namespace &ns,
std::string input); reader::LocationRange &loc);
llvm::Optional<errors::ErrorTree> addNS(llvm::StringRef nsname,
reader::LocationRange &loc);
Namespace &getActiveNS();
}; };
llvm::Expected<std::unique_ptr<Halley>> makeHalleyJIT(SereneContext &ctx);
} // namespace jit
} // namespace serene } // namespace serene
#endif #endif

View File

@ -48,6 +48,7 @@ add_library(serene
# jit.cpp # jit.cpp
jit/engine.cpp jit/engine.cpp
jit/layers.cpp jit/layers.cpp
jit/halley.cpp
# Reader # Reader
reader/reader.cpp reader/reader.cpp

View File

@ -61,6 +61,14 @@ Namespace &SereneContext::getCurrentNS() {
return *namespaces[this->current_ns]; return *namespaces[this->current_ns];
}; };
NSPtr SereneContext::getSharedPtrToNS(llvm::StringRef ns_name) {
if (namespaces.count(ns_name.str()) == 0) {
return nullptr;
}
return namespaces[ns_name.str()];
};
void SereneContext::setOperationPhase(CompilationPhase phase) { void SereneContext::setOperationPhase(CompilationPhase phase) {
this->targetPhase = phase; this->targetPhase = phase;

View File

@ -22,8 +22,6 @@
#include "serene/jit/layers.h" #include "serene/jit/layers.h"
#include "serene/utils.h" #include "serene/utils.h"
#include <llvm/ExecutionEngine/JITSymbol.h>
#include <memory> #include <memory>
namespace serene::jit { namespace serene::jit {

View File

@ -25,11 +25,16 @@
* license applies here * license applies here
*/ */
#include "serene/jit.h" #include "serene/jit/halley.h"
#include "serene/context.h" #include "serene/context.h"
#include "serene/errors/constants.h"
#include "serene/errors/error.h"
#include "serene/namespace.h" #include "serene/namespace.h"
#include "serene/utils.h"
#include <llvm/ADT/None.h>
#include <llvm/ADT/Optional.h>
#include <llvm/ExecutionEngine/Orc/CompileUtils.h> #include <llvm/ExecutionEngine/Orc/CompileUtils.h>
#include <llvm/ExecutionEngine/Orc/ExecutionUtils.h> #include <llvm/ExecutionEngine/Orc/ExecutionUtils.h>
#include <llvm/ExecutionEngine/Orc/IRCompileLayer.h> #include <llvm/ExecutionEngine/Orc/IRCompileLayer.h>
@ -37,6 +42,9 @@
#include <llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h> #include <llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h>
#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h> #include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
#include <llvm/ExecutionEngine/SectionMemoryManager.h> #include <llvm/ExecutionEngine/SectionMemoryManager.h>
#include <llvm/IR/Module.h>
#include <llvm/Support/ErrorHandling.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/ToolOutputFile.h> #include <llvm/Support/ToolOutputFile.h>
#include <mlir/ExecutionEngine/ExecutionEngine.h> #include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/Support/FileUtilities.h> #include <mlir/Support/FileUtilities.h>
@ -49,6 +57,7 @@
namespace serene { namespace serene {
namespace jit {
// TODO: Remove this function and replace it by our own version of // TODO: Remove this function and replace it by our own version of
// error handler // error handler
/// Wrap a string into an llvm::StringError. /// Wrap a string into an llvm::StringError.
@ -61,70 +70,73 @@ static std::string makePackedFunctionName(llvm::StringRef name) {
return "_serene_" + name.str(); return "_serene_" + name.str();
} }
static void packFunctionArguments(llvm::Module *module) { // static void packFunctionArguments(llvm::Module *module) {
auto &ctx = module->getContext(); // auto &ctx = module->getContext();
llvm::IRBuilder<> builder(ctx); // llvm::IRBuilder<> builder(ctx);
llvm::DenseSet<llvm::Function *> interfaceFunctions; // llvm::DenseSet<llvm::Function *> interfaceFunctions;
for (auto &func : module->getFunctionList()) { // for (auto &func : module->getFunctionList()) {
if (func.isDeclaration()) { // if (func.isDeclaration()) {
continue; // continue;
} // }
if (interfaceFunctions.count(&func) != 0) { // if (interfaceFunctions.count(&func) != 0) {
continue; // continue;
} // }
// Given a function `foo(<...>)`, define the interface function // // Given a function `foo(<...>)`, define the interface function
// `mlir_foo(i8**)`. // // `serene_foo(i8**)`.
auto *newType = llvm::FunctionType::get( // auto *newType = llvm::FunctionType::get(
builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(), // builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(),
/*isVarArg=*/false); // /*isVarArg=*/false);
auto newName = makePackedFunctionName(func.getName()); // auto newName = makePackedFunctionName(func.getName());
auto funcCst = module->getOrInsertFunction(newName, newType); // auto funcCst = module->getOrInsertFunction(newName, newType);
llvm::Function *interfaceFunc = // llvm::Function *interfaceFunc =
llvm::cast<llvm::Function>(funcCst.getCallee()); // llvm::cast<llvm::Function>(funcCst.getCallee());
interfaceFunctions.insert(interfaceFunc); // interfaceFunctions.insert(interfaceFunc);
// Extract the arguments from the type-erased argument list and cast them to // // Extract the arguments from the type-erased argument list and cast them
// the proper types. // to
auto *bb = llvm::BasicBlock::Create(ctx); // // the proper types.
bb->insertInto(interfaceFunc); // auto *bb = llvm::BasicBlock::Create(ctx);
builder.SetInsertPoint(bb); // bb->insertInto(interfaceFunc);
llvm::Value *argList = interfaceFunc->arg_begin(); // builder.SetInsertPoint(bb);
llvm::SmallVector<llvm::Value *, COMMON_ARGS_COUNT> args; // llvm::Value *argList = interfaceFunc->arg_begin();
args.reserve(llvm::size(func.args())); // llvm::SmallVector<llvm::Value *, COMMON_ARGS_COUNT> args;
for (auto &indexedArg : llvm::enumerate(func.args())) { // args.reserve(llvm::size(func.args()));
llvm::Value *argIndex = llvm::Constant::getIntegerValue( // for (auto &indexedArg : llvm::enumerate(func.args())) {
builder.getInt64Ty(), llvm::APInt(I64_BIT_SIZE, indexedArg.index())); // llvm::Value *argIndex = llvm::Constant::getIntegerValue(
llvm::Value *argPtrPtr = // builder.getInt64Ty(), llvm::APInt(I64_BIT_SIZE,
builder.CreateGEP(builder.getInt8PtrTy(), argList, argIndex); // indexedArg.index()));
llvm::Value *argPtr = // llvm::Value *argPtrPtr =
builder.CreateLoad(builder.getInt8PtrTy(), argPtrPtr); // builder.CreateGEP(builder.getInt8PtrTy(), argList, argIndex);
llvm::Type *argTy = indexedArg.value().getType(); // llvm::Value *argPtr =
argPtr = builder.CreateBitCast(argPtr, argTy->getPointerTo()); // builder.CreateLoad(builder.getInt8PtrTy(), argPtrPtr);
llvm::Value *arg = builder.CreateLoad(argTy, argPtr); // llvm::Type *argTy = indexedArg.value().getType();
args.push_back(arg); // argPtr = builder.CreateBitCast(argPtr,
} // argTy->getPointerTo()); llvm::Value *arg = builder.CreateLoad(argTy,
// argPtr); args.push_back(arg);
// }
// Call the implementation function with the extracted arguments. // // Call the implementation function with the extracted arguments.
llvm::Value *result = builder.CreateCall(&func, args); // llvm::Value *result = builder.CreateCall(&func, args);
// Assuming the result is one value, potentially of type `void`. // // Assuming the result is one value, potentially of type `void`.
if (!result->getType()->isVoidTy()) { // if (!result->getType()->isVoidTy()) {
llvm::Value *retIndex = llvm::Constant::getIntegerValue( // llvm::Value *retIndex = llvm::Constant::getIntegerValue(
builder.getInt64Ty(), // builder.getInt64Ty(),
llvm::APInt(I64_BIT_SIZE, llvm::size(func.args()))); // llvm::APInt(I64_BIT_SIZE, llvm::size(func.args())));
llvm::Value *retPtrPtr = // llvm::Value *retPtrPtr =
builder.CreateGEP(builder.getInt8PtrTy(), argList, retIndex); // builder.CreateGEP(builder.getInt8PtrTy(), argList, retIndex);
llvm::Value *retPtr = // llvm::Value *retPtr =
builder.CreateLoad(builder.getInt8PtrTy(), retPtrPtr); // builder.CreateLoad(builder.getInt8PtrTy(), retPtrPtr);
retPtr = builder.CreateBitCast(retPtr, result->getType()->getPointerTo()); // retPtr = builder.CreateBitCast(retPtr,
builder.CreateStore(result, retPtr); // result->getType()->getPointerTo()); builder.CreateStore(result,
} // retPtr);
// }
// The interface function returns void. // // The interface function returns void.
builder.CreateRetVoid(); // builder.CreateRetVoid();
} // }
}; // };
void ObjectCache::notifyObjectCompiled(const llvm::Module *m, void ObjectCache::notifyObjectCompiled(const llvm::Module *m,
llvm::MemoryBufferRef objBuffer) { llvm::MemoryBufferRef objBuffer) {
@ -138,11 +150,12 @@ ObjectCache::getObject(const llvm::Module *m) {
auto i = cachedObjects.find(m->getModuleIdentifier()); auto i = cachedObjects.find(m->getModuleIdentifier());
if (i == cachedObjects.end()) { if (i == cachedObjects.end()) {
JIT_LOG("No object for " << m->getModuleIdentifier() HALLEY_LOG("No object for " + m->getModuleIdentifier() +
<< " in cache. Compiling.\n"); " in cache. Compiling.\n");
return nullptr; return nullptr;
} }
JIT_LOG("Object for " << m->getModuleIdentifier() << " loaded from cache.\n"); HALLEY_LOG("Object for " + m->getModuleIdentifier() +
" loaded from cache.\n");
return llvm::MemoryBuffer::getMemBuffer(i->second->getMemBufferRef()); return llvm::MemoryBuffer::getMemBuffer(i->second->getMemBufferRef());
} }
@ -164,144 +177,22 @@ void ObjectCache::dumpToObjectFile(llvm::StringRef outputFilename) {
file->keep(); file->keep();
} }
JIT::JIT(Namespace &ns, bool enableObjectCache, Halley::Halley(serene::SereneContext &ctx,
bool enableGDBNotificationListener, llvm::orc::JITTargetMachineBuilder &&jtmb, llvm::DataLayout &&dl)
bool enablePerfNotificationListener) : cache(ctx.opts.JITenableObjectCache ? new ObjectCache() : nullptr),
: ns(ns), cache(enableObjectCache ? new ObjectCache() : nullptr), gdbListener(ctx.opts.JITenableGDBNotificationListener
gdbListener(enableGDBNotificationListener
? llvm::JITEventListener::createGDBRegistrationListener() ? llvm::JITEventListener::createGDBRegistrationListener()
: nullptr), : nullptr),
perfListener(enablePerfNotificationListener perfListener(ctx.opts.JITenablePerfNotificationListener
? llvm::JITEventListener::createPerfJITEventListener() ? llvm::JITEventListener::createPerfJITEventListener()
: nullptr){}; : nullptr),
jtmb(jtmb), dl(dl), ctx(ctx),
MaybeJIT JIT::make(Namespace &ns, activeNS(ctx.getSharedPtrToNS(ctx.getCurrentNS().name)) {
mlir::ArrayRef<llvm::StringRef> sharedLibPaths, assert(activeNS != nullptr && "Active NS is null!!!");
mlir::Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel,
bool enableObjectCache, bool enableGDBNotificationListener,
bool enablePerfNotificationListener) {
auto jitEngine = std::make_unique<JIT>(ns, enableObjectCache,
enableGDBNotificationListener,
enablePerfNotificationListener);
// Why not the llvmcontext from the SereneContext??
// Sice we're going to pass the ownership of this context to a thread
// safe module later on and we will only create the jit function wrappers
// with it, then it is fine to use a new context.
//
// What might go wrong?
// in a repl env when we have to create new modules on top of each other
// having two different contex might be a problem, but i think since we
// use the first context to generate the IR and the second one to just
// run it.
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext);
auto maybeModule = jitEngine->ns.compileToLLVM();
if (!maybeModule.hasValue()) {
return llvm::None;
}
auto llvmModule = std::move(maybeModule.getValue());
packFunctionArguments(llvmModule.get());
auto dataLayout = llvmModule->getDataLayout();
// Callback to create the object layer with symbol resolution to current
// process and dynamically linked libraries.
auto objectLinkingLayerCreator = [&](llvm::orc::ExecutionSession &session,
const llvm::Triple &tt) {
UNUSED(tt);
auto objectLayer =
std::make_unique<llvm::orc::RTDyldObjectLinkingLayer>(session, []() {
return std::make_unique<llvm::SectionMemoryManager>();
});
// Register JIT event listeners if they are enabled.
if (jitEngine->gdbListener != nullptr) {
objectLayer->registerJITEventListener(*jitEngine->gdbListener);
}
if (jitEngine->perfListener != nullptr) {
objectLayer->registerJITEventListener(*jitEngine->perfListener);
}
// COFF format binaries (Windows) need special handling to deal with
// exported symbol visibility.
// cf llvm/lib/ExecutionEngine/Orc/LLJIT.cpp LLJIT::createObjectLinkingLayer
llvm::Triple targetTriple(llvm::Twine(llvmModule->getTargetTriple()));
if (targetTriple.isOSBinFormatCOFF()) {
objectLayer->setOverrideObjectFlagsWithResponsibilityFlags(true);
objectLayer->setAutoClaimResponsibilityForObjectSymbols(true);
}
// Resolve symbols from shared libraries.
for (auto libPath : sharedLibPaths) {
auto mb = llvm::MemoryBuffer::getFile(libPath);
if (!mb) {
llvm::errs() << "Failed to create MemoryBuffer for: " << libPath
<< "\nError: " << mb.getError().message() << "\n";
continue;
}
auto &JD = session.createBareJITDylib(std::string(libPath));
auto loaded = llvm::orc::DynamicLibrarySearchGenerator::Load(
libPath.data(), dataLayout.getGlobalPrefix());
if (!loaded) {
llvm::errs() << "Could not load " << libPath << ":\n "
<< loaded.takeError() << "\n";
continue;
}
JD.addGenerator(std::move(*loaded));
cantFail(objectLayer->add(JD, std::move(mb.get())));
}
return objectLayer;
};
// Callback to inspect the cache and recompile on demand. This follows Lang's
// LLJITWithObjectCache example.
auto compileFunctionCreator = [&](llvm::orc::JITTargetMachineBuilder JTMB)
-> llvm::Expected<
std::unique_ptr<llvm::orc::IRCompileLayer::IRCompiler>> {
if (jitCodeGenOptLevel) {
JTMB.setCodeGenOptLevel(jitCodeGenOptLevel.getValue());
}
auto TM = JTMB.createTargetMachine();
if (!TM) {
return TM.takeError();
}
return std::make_unique<llvm::orc::TMOwningSimpleCompiler>(
std::move(*TM), jitEngine->cache.get());
};
// Create the LLJIT by calling the LLJITBuilder with 2 callbacks.
auto jit =
cantFail(llvm::orc::LLJITBuilder()
.setCompileFunctionCreator(compileFunctionCreator)
.setObjectLinkingLayerCreator(objectLinkingLayerCreator)
.create());
// Add a ThreadSafemodule to the engine and return.
llvm::orc::ThreadSafeModule tsm(std::move(llvmModule), std::move(ctx));
// TODO: Do we need a module transformer here???
cantFail(jit->addIRModule(std::move(tsm)));
jitEngine->engine = std::move(jit);
// Resolve symbols that are statically linked in the current process.
llvm::orc::JITDylib &mainJD = jitEngine->engine->getMainJITDylib();
mainJD.addGenerator(
cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
dataLayout.getGlobalPrefix())));
return MaybeJIT(std::move(jitEngine));
}; };
llvm::Expected<void (*)(void **)> JIT::lookup(llvm::StringRef name) const { llvm::Expected<void (*)(void **)> Halley::lookup(llvm::StringRef name) const {
auto expectedSymbol = engine->lookup(makePackedFunctionName(name)); auto expectedSymbol = engine->lookup(makePackedFunctionName(name));
// JIT lookup may return an Error referring to strings stored internally by // JIT lookup may return an Error referring to strings stored internally by
@ -329,8 +220,8 @@ llvm::Expected<void (*)(void **)> JIT::lookup(llvm::StringRef name) const {
return fptr; return fptr;
} }
llvm::Error JIT::invokePacked(llvm::StringRef name, llvm::Error Halley::invokePacked(llvm::StringRef name,
llvm::MutableArrayRef<void *> args) const { llvm::MutableArrayRef<void *> args) const {
auto expectedFPtr = lookup(name); auto expectedFPtr = lookup(name);
if (!expectedFPtr) { if (!expectedFPtr) {
return expectedFPtr.takeError(); return expectedFPtr.takeError();
@ -342,13 +233,173 @@ llvm::Error JIT::invokePacked(llvm::StringRef name,
return llvm::Error::success(); return llvm::Error::success();
} }
void JIT::registerSymbols( void Halley::registerSymbols(
llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)> llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
symbolMap) { symbolMap) {
auto &mainJitDylib = engine->getMainJITDylib(); auto &mainJitDylib = engine->getMainJITDylib();
cantFail(mainJitDylib.define( cantFail(mainJitDylib.define(
absoluteSymbols(symbolMap(llvm::orc::MangleAndInterner( absoluteSymbols(symbolMap(llvm::orc::MangleAndInterner(
mainJitDylib.getExecutionSession(), engine->getDataLayout()))))); mainJitDylib.getExecutionSession(), engine->getDataLayout())))));
} };
llvm::Optional<errors::ErrorTree> Halley::addNS(Namespace &ns,
reader::LocationRange &loc) {
// TODO: Fix compileToLLVM to return proper errors
auto maybeModule = ns.compileToLLVM();
if (!maybeModule.hasValue()) {
return errors::makeErrorTree(loc, errors::CompilationError);
}
auto tsm = std::move(maybeModule.getValue());
// tsm.withModuleDo([](llvm::Module &m) { packFunctionArguments(&m); });
// TODO: Make sure that the data layout of the module is the same as the
// engine
cantFail(engine->addIRModule(std::move(tsm)));
return llvm::None;
};
llvm::Optional<errors::ErrorTree> Halley::addNS(llvm::StringRef nsname,
reader::LocationRange &loc) {
auto maybeNS = ctx.sourceManager.readNamespace(ctx, nsname.str(), loc);
if (!maybeNS) {
// TODO: Fix this by making Serene errors compatible with llvm::Error
auto err = maybeNS.getError();
return err;
}
auto &ns = maybeNS.getValue();
auto err = addNS(*ns, loc);
if (err) {
return err.getValue();
}
return llvm::None;
};
MaybeJIT Halley::make(SereneContext &serene_ctx,
llvm::orc::JITTargetMachineBuilder &&jtmb) {
auto dl = jtmb.getDefaultDataLayoutForTarget();
if (!dl) {
return dl.takeError();
}
auto jitEngine =
std::make_unique<Halley>(serene_ctx, std::move(jtmb), std::move(*dl));
// Why not the llvmcontext from the SereneContext??
// Sice we're going to pass the ownership of this context to a thread
// safe module later on and we will only create the jit function wrappers
// with it, then it is fine to use a new context.
//
// What might go wrong?
// in a repl env when we have to create new modules on top of each other
// having two different contex might be a problem, but i think since we
// use the first context to generate the IR and the second one to just
// run it.
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext);
// Callback to create the object layer with symbol resolution to current
// process and dynamically linked libraries.
auto objectLinkingLayerCreator = [&](llvm::orc::ExecutionSession &session,
const llvm::Triple &tt) {
UNUSED(tt);
auto objectLayer =
std::make_unique<llvm::orc::RTDyldObjectLinkingLayer>(session, []() {
return std::make_unique<llvm::SectionMemoryManager>();
});
// Register JIT event listeners if they are enabled.
if (jitEngine->gdbListener != nullptr) {
objectLayer->registerJITEventListener(*jitEngine->gdbListener);
}
if (jitEngine->perfListener != nullptr) {
objectLayer->registerJITEventListener(*jitEngine->perfListener);
}
// COFF format binaries (Windows) need special handling to deal with
// exported symbol visibility.
// cf llvm/lib/ExecutionEngine/Orc/LLJIT.cpp
// LLJIT::createObjectLinkingLayer
llvm::Triple targetTriple(llvm::Twine(serene_ctx.targetTriple));
if (targetTriple.isOSBinFormatCOFF()) {
objectLayer->setOverrideObjectFlagsWithResponsibilityFlags(true);
objectLayer->setAutoClaimResponsibilityForObjectSymbols(true);
}
// Resolve symbols from shared libraries.
// for (auto libPath : sharedLibPaths) {
// auto mb = llvm::MemoryBuffer::getFile(libPath);
// if (!mb) {
// llvm::errs() << "Failed to create MemoryBuffer for: " << libPath
// << "\nError: " << mb.getError().message() << "\n";
// continue;
// }
// auto &JD = session.createBareJITDylib(std::string(libPath));
// auto loaded = llvm::orc::DynamicLibrarySearchGenerator::Load(
// libPath.data(), dataLayout.getGlobalPrefix());
// if (!loaded) {
// llvm::errs() << "Could not load " << libPath << ":\n "
// << loaded.takeError() << "\n";
// continue;
// }
// JD.addGenerator(std::move(*loaded));
// cantFail(objectLayer->add(JD, std::move(mb.get())));
// }
return objectLayer;
};
// Callback to inspect the cache and recompile on demand. This follows Lang's
// LLJITWithObjectCache example.
auto compileFunctionCreator = [&](llvm::orc::JITTargetMachineBuilder JTMB)
-> llvm::Expected<
std::unique_ptr<llvm::orc::IRCompileLayer::IRCompiler>> {
llvm::CodeGenOpt::Level jitCodeGenOptLevel =
static_cast<llvm::CodeGenOpt::Level>(serene_ctx.getOptimizatioLevel());
JTMB.setCodeGenOptLevel(jitCodeGenOptLevel);
auto targetMachine = JTMB.createTargetMachine();
if (!targetMachine) {
return targetMachine.takeError();
}
return std::make_unique<llvm::orc::TMOwningSimpleCompiler>(
std::move(*targetMachine), jitEngine->cache.get());
};
// Create the LLJIT by calling the LLJITBuilder with 2 callbacks.
auto jit =
cantFail(llvm::orc::LLJITBuilder()
.setCompileFunctionCreator(compileFunctionCreator)
.setObjectLinkingLayerCreator(objectLinkingLayerCreator)
.create());
jitEngine->engine = std::move(jit);
// Resolve symbols that are statically linked in the current process.
llvm::orc::JITDylib &mainJD = jitEngine->engine->getMainJITDylib();
mainJD.addGenerator(
cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
jitEngine->dl.getGlobalPrefix())));
return MaybeJIT(std::move(jitEngine));
};
Namespace &Halley::getActiveNS() { return *activeNS; };
llvm::Expected<std::unique_ptr<Halley>> makeHalleyJIT(SereneContext &ctx) {
llvm::orc::JITTargetMachineBuilder jtmb(ctx.getTargetTriple());
return Halley::make(ctx, std::move(jtmb));
};
} // namespace jit
} // namespace serene } // namespace serene

View File

@ -19,8 +19,8 @@
#include "serene/passes.h" #include "serene/passes.h"
#include "serene/slir/dialect.h" #include "serene/slir/dialect.h"
#include <memory>
#include <mlir/Conversion/AffineToStandard/AffineToStandard.h> #include <mlir/Conversion/AffineToStandard/AffineToStandard.h>
#include <mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h>
#include <mlir/Conversion/LLVMCommon/ConversionTarget.h> #include <mlir/Conversion/LLVMCommon/ConversionTarget.h>
#include <mlir/Conversion/LLVMCommon/TypeConverter.h> #include <mlir/Conversion/LLVMCommon/TypeConverter.h>
#include <mlir/Conversion/SCFToStandard/SCFToStandard.h> #include <mlir/Conversion/SCFToStandard/SCFToStandard.h>
@ -32,6 +32,8 @@
#include <mlir/Pass/Pass.h> #include <mlir/Pass/Pass.h>
#include <mlir/Transforms/DialectConversion.h> #include <mlir/Transforms/DialectConversion.h>
#include <memory>
namespace serene::passes { namespace serene::passes {
struct SLIRToLLVMDialect struct SLIRToLLVMDialect
: public mlir::PassWrapper<SLIRToLLVMDialect, : public mlir::PassWrapper<SLIRToLLVMDialect,
@ -68,7 +70,8 @@ void SLIRToLLVMDialect::runOnOperation() {
mlir::RewritePatternSet patterns(&getContext()); mlir::RewritePatternSet patterns(&getContext());
populateStdToLLVMConversionPatterns(typeConverter, patterns); populateStdToLLVMConversionPatterns(typeConverter, patterns);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns);
// patterns.add<PrintOpLowering>(&getContext()); // patterns.add<PrintOpLowering>(&getContext());
// We want to completely lower to LLVM, so we use a `FullConversion`. This // We want to completely lower to LLVM, so we use a `FullConversion`. This

View File

@ -109,32 +109,30 @@ SERENE_EXPORT exprs::MaybeNode eval(SereneContext &ctx, exprs::Ast &input) {
UNUSED(input); UNUSED(input);
auto loc = reader::LocationRange::UnknownLocation("nsname"); auto loc = reader::LocationRange::UnknownLocation("nsname");
auto err = ctx.jit->addNS("docs.examples.hello_world"); auto err = ctx.jit->addNS("docs.examples.hello_world", loc);
if (err) { if (err) {
llvm::errs() << err; auto es = err.getValue();
auto e = errors::makeErrorTree(loc, errors::NSLoadError); auto nsloadErr = errors::makeError(loc, errors::NSLoadError);
es.push_back(nsloadErr);
return exprs::makeErrorNode(loc, errors::NSLoadError); return exprs::MaybeNode::error(es);
} }
std::string tmp("main"); std::string tmp("main");
llvm::ExitOnError e; llvm::ExitOnError e;
// Get the anonymous expression's JITSymbol. // Get the anonymous expression's JITSymbol.
auto sym = e(ctx.jit->lookup(tmp)); auto sym = e(ctx.jit->lookup(tmp));
llvm::outs() << "eval here\n"; llvm::outs() << "eval here\n";
// Get the symbol's address and cast it to the right type (takes no
// arguments, returns a double) so we can call it as a native function.
auto *f = (int (*)())(intptr_t)sym.getAddress();
f(); sym((void **)3);
err = ctx.jit->addAst(input); // err = ctx.jit->addAst(input);
if (err) { // if (err) {
llvm::errs() << err; // llvm::errs() << err;
auto e = errors::makeErrorTree(loc, errors::NSLoadError); // auto e = errors::makeErrorTree(loc, errors::NSLoadError);
return exprs::makeErrorNode(loc, errors::NSLoadError); // return exprs::makeErrorNode(loc, errors::NSLoadError);
} // }
return exprs::make<exprs::Number>(loc, "4", false, false); return exprs::make<exprs::Number>(loc, "4", false, false);
}; };

View File

@ -73,7 +73,7 @@ int main(int argc, char *argv[]) {
// Read line // Read line
std::string line; std::string line;
std::string result; std::string result;
std::string prompt = ctx->jit->getCurrentNS().name + "> "; std::string prompt = ctx->jit->getActiveNS().name + "> ";
auto quit = linenoise::Readline(prompt.c_str(), line); auto quit = linenoise::Readline(prompt.c_str(), line);

View File

@ -16,7 +16,7 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>. * along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
#include "serene/jit.h" #include "serene/jit/halley.h"
#include "serene/namespace.h" #include "serene/namespace.h"
#include "serene/reader/location.h" #include "serene/reader/location.h"
#include "serene/reader/reader.h" #include "serene/reader/reader.h"