From e5c5a681282cb78e55e1f7654264fdefcbf000c4 Mon Sep 17 00:00:00 2001 From: Sameer Rahmani Date: Sat, 16 Oct 2021 16:15:56 +0100 Subject: [PATCH] Clean up the slir lowering impl --- docs/videos.org | 10 ++++++- src/libserene/namespace.cpp | 4 +-- src/libserene/passes/slir_lowering.cpp | 34 +++++++++++++++++------- src/libserene/passes/to_llvm_dialect.cpp | 4 +-- 4 files changed, 36 insertions(+), 16 deletions(-) diff --git a/docs/videos.org b/docs/videos.org index f76857a..6fe079c 100644 --- a/docs/videos.org +++ b/docs/videos.org @@ -404,7 +404,8 @@ There will be an episode dedicated to eache of these - [X] Define the operations - [X] Walk the AST and generate the operations -* Episode 10 - Pass Infrastructure +* DONE Episode 10 - Pass Infrastructure +CLOSED: [2021-10-15 Fri 14:17] ** The next Step ** Updates: *** CMake changes @@ -494,8 +495,15 @@ Source code -> IR X -> IR Y -> IR Z -> ... -> Target Code #+END_SRC * Episode 11 - Lowering SLIR +** Overview ** Dialect lowering *** Why? *** Transforming a dialect to another dialect or LLVM IR *** The goal is to lower SLIR to LLVM IR directly or indirectly. +** Dialect Conversions +This framework allows for transforming a set of illegal operations to a set of legal ones. +*** Target Conversion +*** Rewrite Patterns +*** Type Converter +** Full vs Partial Conversion ** Dealing with Pass failures diff --git a/src/libserene/namespace.cpp b/src/libserene/namespace.cpp index 6599d7e..adb6134 100644 --- a/src/libserene/namespace.cpp +++ b/src/libserene/namespace.cpp @@ -49,9 +49,7 @@ namespace serene { Namespace::Namespace(SereneContext &ctx, llvm::StringRef ns_name, llvm::Optional filename) - : ctx(ctx), name(ns_name) - -{ + : ctx(ctx), name(ns_name) { if (filename.hasValue()) { this->filename.emplace(filename.getValue().str()); } diff --git a/src/libserene/passes/slir_lowering.cpp b/src/libserene/passes/slir_lowering.cpp index 9f96414..7a11613 100644 --- a/src/libserene/passes/slir_lowering.cpp +++ b/src/libserene/passes/slir_lowering.cpp @@ -31,6 +31,8 @@ namespace serene::passes { +// ---------------------------------------------------------------------------- +// ValueOp lowering to constant op struct ValueOpLowering : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -45,9 +47,11 @@ ValueOpLowering::matchAndRewrite(serene::slir::ValueOp op, auto value = op.value(); mlir::Location loc = op.getLoc(); - llvm::SmallVector arg_types(0); + llvm::SmallVector arg_types(0); auto func_type = rewriter.getFunctionType(arg_types, rewriter.getI64Type()); - auto fn = rewriter.create(loc, "randomname", func_type); + // TODO: use a mechanism to generate unique names + auto fn = rewriter.create(loc, "randomname", func_type); + if (!fn) { op.emitOpError("Value Rewrite fn is null"); return mlir::failure(); @@ -55,6 +59,8 @@ ValueOpLowering::matchAndRewrite(serene::slir::ValueOp op, auto entryBlock = fn.addEntryBlock(); rewriter.setInsertionPointToStart(entryBlock); + + // Since we only support i64 at the moment we use ConstantIntOp auto retVal = rewriter .create(loc, (int64_t)value, rewriter.getI64Type()) @@ -68,10 +74,13 @@ ValueOpLowering::matchAndRewrite(serene::slir::ValueOp op, } fn.setPrivate(); + + // Erase the original ValueOP rewriter.eraseOp(op); return mlir::success(); } +// ---------------------------------------------------------------------------- // Fn lowering pattern struct FnOpLowering : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -127,7 +136,9 @@ FnOpLowering::matchAndRewrite(serene::slir::FnOp op, return mlir::success(); } +// ---------------------------------------------------------------------------- // SLIR lowering pass +// This Pass will lower SLIR to MLIR's standard dialect. struct SLIRToMLIRPass : public mlir::PassWrapper> { @@ -137,6 +148,8 @@ struct SLIRToMLIRPass mlir::ModuleOp getModule(); }; +// Mark what dialects we need for this pass. It's basically translate to what +// dialects do we want to lower to void SLIRToMLIRPass::getDependentDialects( mlir::DialectRegistry ®istry) const { registry.insert(); @@ -156,21 +169,24 @@ void SLIRToMLIRPass::runOnModule() { mlir::ConversionTarget target(getContext()); // We define the specific operations, or dialects, that are legal targets for - // this lowering. In our case, we are lowering to a combination of the - // `Affine`, `MemRef` and `Standard` dialects. + // this lowering. In our case, we are lowering to the `Standard` dialects. target.addLegalDialect(); - // We also define the Toy dialect as Illegal so that the conversion will fail - // if any of these operations are *not* converted. Given that we actually want - // a partial lowering, we explicitly mark the Toy operations that don't want - // to lower, `toy.print`, as `legal`. + // We also define the SLIR dialect as Illegal so that the conversion will fail + // if any of these operations are *not* converted. target.addIllegalDialect(); + + // Mark operations that are LEGAL for this pass. It means that we don't lower + // them is this pass but we will in another pass. So we don't want to get + // an error since we are not lowering them. // target.addLegalOp(); target.addLegalOp(); // Now that the conversion target has been defined, we just need to provide - // the set of patterns that will lower the Toy operations. + // the set of patterns that will lower the SLIR operations. mlir::RewritePatternSet patterns(&getContext()); + + // Pattern to lower ValueOp and FnOp patterns.add(&getContext()); // With the target and rewrite patterns defined, we can now attempt the diff --git a/src/libserene/passes/to_llvm_dialect.cpp b/src/libserene/passes/to_llvm_dialect.cpp index e8ab097..b6cf1af 100644 --- a/src/libserene/passes/to_llvm_dialect.cpp +++ b/src/libserene/passes/to_llvm_dialect.cpp @@ -37,7 +37,7 @@ struct SLIRToLLVMDialect : public mlir::PassWrapper> { void getDependentDialects(mlir::DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } void runOnOperation() final; @@ -67,8 +67,6 @@ void SLIRToLLVMDialect::runOnOperation() { // set of legal ones. mlir::RewritePatternSet patterns(&getContext()); - // mlir::populateAffineToStdConversionPatterns(patterns); - // populateLoopToStdConversionPatterns(patterns); populateStdToLLVMConversionPatterns(typeConverter, patterns); // patterns.add(&getContext());