Clean up the slir lowering impl
This commit is contained in:
parent
5a9c256325
commit
e5c5a68128
|
@ -404,7 +404,8 @@ There will be an episode dedicated to eache of these
|
|||
- [X] Define the operations
|
||||
- [X] Walk the AST and generate the operations
|
||||
|
||||
* Episode 10 - Pass Infrastructure
|
||||
* DONE Episode 10 - Pass Infrastructure
|
||||
CLOSED: [2021-10-15 Fri 14:17]
|
||||
** The next Step
|
||||
** Updates:
|
||||
*** CMake changes
|
||||
|
@ -494,8 +495,15 @@ Source code -> IR X -> IR Y -> IR Z -> ... -> Target Code
|
|||
#+END_SRC
|
||||
|
||||
* Episode 11 - Lowering SLIR
|
||||
** Overview
|
||||
** Dialect lowering
|
||||
*** Why?
|
||||
*** Transforming a dialect to another dialect or LLVM IR
|
||||
*** The goal is to lower SLIR to LLVM IR directly or indirectly.
|
||||
** Dialect Conversions
|
||||
This framework allows for transforming a set of illegal operations to a set of legal ones.
|
||||
*** Target Conversion
|
||||
*** Rewrite Patterns
|
||||
*** Type Converter
|
||||
** Full vs Partial Conversion
|
||||
** Dealing with Pass failures
|
||||
|
|
|
@ -49,9 +49,7 @@ namespace serene {
|
|||
|
||||
Namespace::Namespace(SereneContext &ctx, llvm::StringRef ns_name,
|
||||
llvm::Optional<llvm::StringRef> filename)
|
||||
: ctx(ctx), name(ns_name)
|
||||
|
||||
{
|
||||
: ctx(ctx), name(ns_name) {
|
||||
if (filename.hasValue()) {
|
||||
this->filename.emplace(filename.getValue().str());
|
||||
}
|
||||
|
|
|
@ -31,6 +31,8 @@
|
|||
|
||||
namespace serene::passes {
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// ValueOp lowering to constant op
|
||||
struct ValueOpLowering : public mlir::OpRewritePattern<serene::slir::ValueOp> {
|
||||
using OpRewritePattern<serene::slir::ValueOp>::OpRewritePattern;
|
||||
|
||||
|
@ -45,9 +47,11 @@ ValueOpLowering::matchAndRewrite(serene::slir::ValueOp op,
|
|||
auto value = op.value();
|
||||
mlir::Location loc = op.getLoc();
|
||||
|
||||
llvm::SmallVector<mlir::Type, 4> arg_types(0);
|
||||
llvm::SmallVector<mlir::Type, 1> arg_types(0);
|
||||
auto func_type = rewriter.getFunctionType(arg_types, rewriter.getI64Type());
|
||||
auto fn = rewriter.create<mlir::FuncOp>(loc, "randomname", func_type);
|
||||
// TODO: use a mechanism to generate unique names
|
||||
auto fn = rewriter.create<mlir::FuncOp>(loc, "randomname", func_type);
|
||||
|
||||
if (!fn) {
|
||||
op.emitOpError("Value Rewrite fn is null");
|
||||
return mlir::failure();
|
||||
|
@ -55,6 +59,8 @@ ValueOpLowering::matchAndRewrite(serene::slir::ValueOp op,
|
|||
|
||||
auto entryBlock = fn.addEntryBlock();
|
||||
rewriter.setInsertionPointToStart(entryBlock);
|
||||
|
||||
// Since we only support i64 at the moment we use ConstantIntOp
|
||||
auto retVal = rewriter
|
||||
.create<mlir::ConstantIntOp>(loc, (int64_t)value,
|
||||
rewriter.getI64Type())
|
||||
|
@ -68,10 +74,13 @@ ValueOpLowering::matchAndRewrite(serene::slir::ValueOp op,
|
|||
}
|
||||
|
||||
fn.setPrivate();
|
||||
|
||||
// Erase the original ValueOP
|
||||
rewriter.eraseOp(op);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Fn lowering pattern
|
||||
struct FnOpLowering : public mlir::OpRewritePattern<serene::slir::FnOp> {
|
||||
using OpRewritePattern<serene::slir::FnOp>::OpRewritePattern;
|
||||
|
@ -127,7 +136,9 @@ FnOpLowering::matchAndRewrite(serene::slir::FnOp op,
|
|||
return mlir::success();
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// SLIR lowering pass
|
||||
// This Pass will lower SLIR to MLIR's standard dialect.
|
||||
struct SLIRToMLIRPass
|
||||
: public mlir::PassWrapper<SLIRToMLIRPass,
|
||||
mlir::OperationPass<mlir::ModuleOp>> {
|
||||
|
@ -137,6 +148,8 @@ struct SLIRToMLIRPass
|
|||
mlir::ModuleOp getModule();
|
||||
};
|
||||
|
||||
// Mark what dialects we need for this pass. It's basically translate to what
|
||||
// dialects do we want to lower to
|
||||
void SLIRToMLIRPass::getDependentDialects(
|
||||
mlir::DialectRegistry ®istry) const {
|
||||
registry.insert<mlir::StandardOpsDialect>();
|
||||
|
@ -156,21 +169,24 @@ void SLIRToMLIRPass::runOnModule() {
|
|||
mlir::ConversionTarget target(getContext());
|
||||
|
||||
// We define the specific operations, or dialects, that are legal targets for
|
||||
// this lowering. In our case, we are lowering to a combination of the
|
||||
// `Affine`, `MemRef` and `Standard` dialects.
|
||||
// this lowering. In our case, we are lowering to the `Standard` dialects.
|
||||
target.addLegalDialect<mlir::StandardOpsDialect>();
|
||||
|
||||
// We also define the Toy dialect as Illegal so that the conversion will fail
|
||||
// if any of these operations are *not* converted. Given that we actually want
|
||||
// a partial lowering, we explicitly mark the Toy operations that don't want
|
||||
// to lower, `toy.print`, as `legal`.
|
||||
// We also define the SLIR dialect as Illegal so that the conversion will fail
|
||||
// if any of these operations are *not* converted.
|
||||
target.addIllegalDialect<serene::slir::SereneDialect>();
|
||||
|
||||
// Mark operations that are LEGAL for this pass. It means that we don't lower
|
||||
// them is this pass but we will in another pass. So we don't want to get
|
||||
// an error since we are not lowering them.
|
||||
// target.addLegalOp<serene::slir::PrintOp>();
|
||||
target.addLegalOp<mlir::FuncOp>();
|
||||
|
||||
// Now that the conversion target has been defined, we just need to provide
|
||||
// the set of patterns that will lower the Toy operations.
|
||||
// the set of patterns that will lower the SLIR operations.
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
// Pattern to lower ValueOp and FnOp
|
||||
patterns.add<ValueOpLowering, FnOpLowering>(&getContext());
|
||||
|
||||
// With the target and rewrite patterns defined, we can now attempt the
|
||||
|
|
|
@ -37,7 +37,7 @@ struct SLIRToLLVMDialect
|
|||
: public mlir::PassWrapper<SLIRToLLVMDialect,
|
||||
mlir::OperationPass<mlir::ModuleOp>> {
|
||||
void getDependentDialects(mlir::DialectRegistry ®istry) const override {
|
||||
registry.insert<mlir::LLVM::LLVMDialect, mlir::scf::SCFDialect>();
|
||||
registry.insert<mlir::LLVM::LLVMDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() final;
|
||||
|
@ -67,8 +67,6 @@ void SLIRToLLVMDialect::runOnOperation() {
|
|||
// set of legal ones.
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
// mlir::populateAffineToStdConversionPatterns(patterns);
|
||||
// populateLoopToStdConversionPatterns(patterns);
|
||||
populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
// patterns.add<PrintOpLowering>(&getContext());
|
||||
|
|
Loading…
Reference in New Issue