/* -*- C++ -*-
 * Serene Programming Language
 *
 * Copyright (c) 2019-2022 Sameer Rahmani
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, version 2.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include "serene/context.h"
#include "serene/conventions.h"
#include "serene/passes.h"
#include "serene/slir/dialect.h"
#include "serene/slir/type_converter.h"
#include "serene/utils.h"

// NOTE(review): the bracketed include list below and the template arguments
// marked NOTE(review) further down were reconstructed from how each name is
// used in this file; confirm them against the build.
#include <llvm/ADT/ArrayRef.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/Support/Casting.h>
#include <llvm/Support/raw_ostream.h>
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/Matchers.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Transforms/DialectConversion.h>

#include <cstdint>
#include <memory>
#include <string>

namespace ll = mlir::LLVM;

namespace serene::passes {

/// Return the `llvm.mlir.global` named `name` that holds the raw bytes of
/// `value` as an `!llvm.array<N x i8>` (N == value.size()).
///
/// If no global with that name exists, one is created (constant, linkonce)
/// at the start of `module`. The builder's insertion point is restored
/// before returning. No NUL terminator is appended.
static ll::GlobalOp getOrCreateInternalString(mlir::Location loc,
                                              mlir::OpBuilder &builder,
                                              llvm::StringRef name,
                                              llvm::StringRef value,
                                              mlir::ModuleOp module) {
  if (auto existing = module.lookupSymbol<ll::GlobalOp>(name)) {
    return existing;
  }

  // Create the global at the entry of the module.
  mlir::OpBuilder::InsertionGuard insertGuard(builder);
  builder.setInsertionPointToStart(module.getBody());

  auto type = ll::LLVMArrayType::get(
      mlir::IntegerType::get(builder.getContext(), I8_SIZE), value.size());

  // TODO: Do we want link once ?
  return builder.create<ll::GlobalOp>(loc, type, /*isConstant=*/true,
                                      ll::Linkage::Linkonce, name,
                                      builder.getStringAttr(value),
                                      /*alignment=*/0);
}

/// Emit, at the builder's current insertion point, the ops that produce an
/// `i8*` pointing at the first character of the byte-array `global`.
static mlir::Value getPtrToInternalString(mlir::OpBuilder &builder,
                                          ll::GlobalOp global) {
  auto loc = global.getLoc();
  auto *ctx = builder.getContext();
  auto i8Type = mlir::IntegerType::get(ctx, I8_SIZE);

  // `addressof` yields a pointer to the whole array global.
  mlir::Value globalPtr = builder.create<ll::AddressOfOp>(loc, global);
  mlir::Value cst0 = builder.create<ll::ConstantOp>(
      loc, mlir::IntegerType::get(ctx, I64_SIZE),
      builder.getIntegerAttr(builder.getIndexType(), 0));

  // Two zero indices are required: the first steps through the pointer to
  // the array, the second selects element 0 of the array. With a single
  // index the result would still be a pointer to the array, not an `i8*`.
  return builder.create<ll::GEPOp>(loc, ll::LLVMPointerType::get(i8Type),
                                   globalPtr,
                                   llvm::ArrayRef<mlir::Value>({cst0, cst0}));
}

/// Return the global named `name` holding a Serene string struct (type from
/// `slir::getStringTypeinLLVM`) for `value`, creating it at the start of
/// `module` if it does not exist yet.
///
/// The struct is built in the global's initializer region:
///   field 0 - pointer to the first character of the raw bytes,
///   field 1 - `len` as an i32.
/// The raw bytes live in a separate global named `<name>.data`.
static ll::GlobalOp getOrCreateString(mlir::Location loc,
                                      mlir::OpBuilder &builder,
                                      llvm::StringRef name,
                                      llvm::StringRef value, uint32_t len,
                                      mlir::ModuleOp module) {
  if (auto existing = module.lookupSymbol<ll::GlobalOp>(name)) {
    return existing;
  }

  auto *ctx = builder.getContext();
  mlir::OpBuilder::InsertionGuard insertGuard(builder);
  builder.setInsertionPointToStart(module.getBody());

  auto type = slir::getStringTypeinLLVM(*ctx);
  // No attribute initializer: the value is computed in the init region.
  mlir::Attribute initValue{};
  auto global = builder.create<ll::GlobalOp>(
      loc, type, /*isConstant=*/true, ll::Linkage::Linkonce, name, initValue);

  auto &initRegion = global.getInitializerRegion();
  auto *block = builder.createBlock(&initRegion);
  builder.setInsertionPoint(block, block->begin());

  mlir::Value structInstant = builder.create<ll::UndefOp>(loc, type);

  // The raw bytes need a symbol distinct from `name`. The struct global
  // above already owns `name`, so looking that up again would return the
  // struct itself and the characters would never be emitted.
  const std::string dataName = name.str() + ".data";
  auto strOp =
      getOrCreateInternalString(loc, builder, dataName, value, module);
  auto ptrToStr = getPtrToInternalString(builder, strOp);

  auto length = builder.create<ll::ConstantOp>(
      loc, mlir::IntegerType::get(ctx, I32_SIZE),
      builder.getI32IntegerAttr(len));

  // Setting the string pointer field
  structInstant = builder.create<ll::InsertValueOp>(
      loc, structInstant.getType(), structInstant, ptrToStr,
      builder.getI64ArrayAttr(0));

  // Setting the len field
  structInstant = builder.create<ll::InsertValueOp>(
      loc, structInstant.getType(), structInstant, length,
      builder.getI64ArrayAttr(1));

  builder.create<ll::ReturnOp>(loc, structInstant);
  return global;
}

/// Return the global holding the Serene symbol `ns/name`, creating it (and
/// the string globals it points to) at the start of `module` if needed.
///
/// The global is named after the mangled fully-qualified symbol name. Its
/// initializer stores the address of the ns string global in field 0 and the
/// address of the name string global in field 1.
// NOTE(review): assumes the struct from slir::getSymbolTypeinLLVM is laid
// out as {ns, name}; confirm against its definition.
static ll::GlobalOp getOrCreateSymbol(mlir::Location loc,
                                      mlir::OpBuilder &builder,
                                      llvm::StringRef ns, llvm::StringRef name,
                                      mlir::ModuleOp module) {
  // The fully-qualified name must be built before it is mangled. Mangling
  // the still-empty buffer would map every symbol to the same global.
  std::string fqName;
  makeFQSymbolName(ns, name, fqName);
  auto symName = serene::mangleInternalSymName(fqName);

  if (auto existing = module.lookupSymbol<ll::GlobalOp>(symName)) {
    return existing;
  }

  auto *ctx = builder.getContext();
  mlir::OpBuilder::InsertionGuard insertGuard(builder);
  builder.setInsertionPointToStart(module.getBody());

  mlir::Attribute initValue{};
  auto type = slir::getSymbolTypeinLLVM(*ctx);

  // We want to allow merging the strings representing the ns or name part
  // of the symbol with other modules to unify them.
  ll::Linkage linkage = ll::Linkage::Linkonce;

  auto global = builder.create<ll::GlobalOp>(loc, type, /*isConstant=*/true,
                                             linkage, symName, initValue);

  auto &initRegion = global.getInitializerRegion();
  auto *block = builder.createBlock(&initRegion);
  builder.setInsertionPoint(block, block->begin());

  mlir::Value structInstant = builder.create<ll::UndefOp>(loc, type);

  // The ns/name strings are stored in globals named after their mangled
  // text. getOrCreateString returns an existing global of the same name, so
  // equal strings share one global.
  auto mangledNSName = serene::mangleInternalStringName(ns);
  auto nsField = getOrCreateString(loc, builder, mangledNSName, ns,
                                   static_cast<uint32_t>(ns.size()), module);
  auto ptrToNs = builder.create<ll::AddressOfOp>(loc, nsField);

  auto mangledName = serene::mangleInternalStringName(name);
  auto nameField =
      getOrCreateString(loc, builder, mangledName, name,
                        static_cast<uint32_t>(name.size()), module);
  auto ptrToName = builder.create<ll::AddressOfOp>(loc, nameField);

  // Setting the ns field
  structInstant = builder.create<ll::InsertValueOp>(
      loc, structInstant.getType(), structInstant, ptrToNs,
      builder.getI64ArrayAttr(0));

  // Setting the name field (a distinct field; it must not overwrite the ns)
  structInstant = builder.create<ll::InsertValueOp>(
      loc, structInstant.getType(), structInstant, ptrToName,
      builder.getI64ArrayAttr(1));

  builder.create<ll::ReturnOp>(loc, structInstant);
  return global;
}

/// Lowers `slir::SymbolOp` by materialising the symbol as an LLVM global
/// (see getOrCreateSymbol) and erasing the op.
struct LowerSymbol : public mlir::OpConversionPattern<serene::slir::SymbolOp> {
  using OpConversionPattern::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(serene::slir::SymbolOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override;
};

mlir::LogicalResult
LowerSymbol::matchAndRewrite(serene::slir::SymbolOp op, OpAdaptor adaptor,
                             mlir::ConversionPatternRewriter &rewriter) const {
  (void)adaptor;

  auto ns = op.ns();
  auto name = op.name();
  auto loc = op.getLoc();
  auto module = op->getParentOfType<mlir::ModuleOp>();

  auto global = getOrCreateSymbol(loc, rewriter, ns, name, module);
  (void)global;

  // TODO: The op is erased unconditionally. That is only valid while its
  // result has no uses; replace the uses (e.g. with the address of
  // `global`) before erasing.
  rewriter.eraseOp(op);
  return mlir::success();
}

/// Lowers `slir::DefineOp`:
///  - a value produced by a constant-like op becomes a
///    `slir::DefineConstantOp` carrying the constant attribute;
///  - a value produced by a `slir::FnOp` just drops the define (for now);
///  - a top-level define is dropped (not lowered yet);
///  - any other define gets a module-level `llvm.mlir.global` initialised to
///    `undef`, and is replaced by an op that stores `value` into it.
struct LowerDefine : public mlir::OpConversionPattern<serene::slir::DefineOp> {
  using OpConversionPattern::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(serene::slir::DefineOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override;
};

mlir::LogicalResult
LowerDefine::matchAndRewrite(serene::slir::DefineOp op, OpAdaptor adaptor,
                             mlir::ConversionPatternRewriter &rewriter) const {
  (void)adaptor;

  auto value = op.value();
  bool isTopLevel = op.is_top_level().getValueOr(false);

  // `getDefiningOp` is null when `value` is a block argument. Such values
  // skip the constant/function special cases instead of being dereferenced.
  if (auto *valueop = value.getDefiningOp()) {
    // If the value that we bind a name to is a constant, rewrite to
    // `define_constant`.
    // TODO: Replace the isConstantLike with a `hasTrait` call
    if (mlir::detail::isConstantLike(valueop)) {
      mlir::Attribute constantValue;
      if (!mlir::matchPattern(value, mlir::m_Constant(&constantValue))) {
        PASS_LOG(
            "Failure: The constant like op don't have a constant attribute.");
        return mlir::failure();
      }

      rewriter.replaceOpWithNewOp<serene::slir::DefineConstantOp>(
          op, op.sym_name(), constantValue, rewriter.getBoolAttr(isTopLevel),
          op.sym_visibilityAttr());
      // TODO: Erase the valueop if it has no other 'use' in the IR
      return mlir::success();
    }

    // If the value is a function literal (like an anonymous function) the
    // define is simply erased for now.
    // TODO: Rewrite it to a func.func instead.
    // NOTE(review): reconstructed op name; confirm it is slir::FnOp.
    if (mlir::isa<serene::slir::FnOp>(valueop)) {
      rewriter.eraseOp(op);
      return mlir::success();
    }
  }

  if (isTopLevel) {
    // TODO: Lower top-level defines; they are currently dropped. If the
    // define's result has uses they should be replaced (e.g. with a
    // SymbolOp for ns/name) first.
    rewriter.eraseOp(op);
    return mlir::success();
  }

  auto llvmType = typeConverter->convertType(value.getType());
  if (!llvmType) {
    PASS_LOG("Failure: Can't convert the type of the defined value.");
    return mlir::failure();
  }

  auto loc = op.getLoc();
  auto moduleOp = op->getParentOfType<mlir::ModuleOp>();

  std::string fqsym;
  makeFQSymbolName(moduleOp.getNameAttr().getValue(), op.getName(), fqsym);

  // TODO: [lib] If we're building an executable `linkonce` is a good choice
  // but for a library we need to choose a better link type
  ll::Linkage linkage = ll::Linkage::Linkonce;

  {
    // Create the backing global at the start of the module body, then
    // restore the rewriter's insertion point.
    mlir::OpBuilder::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPointToStart(moduleOp.getBody());

    mlir::Attribute initAttr{};
    auto globalOp =
        rewriter.create<ll::GlobalOp>(loc, llvmType, /*isConstant=*/false,
                                      linkage, fqsym, initAttr);

    // Initialise the global to `undef`; the real value is stored where the
    // define appears.
    auto *block = rewriter.createBlock(&globalOp.getInitializerRegion());
    rewriter.setInsertionPointToStart(block);
    auto undef = rewriter.create<ll::UndefOp>(loc, llvmType);
    rewriter.create<ll::ReturnOp>(loc, undef.getResult());
  }

  rewriter.setInsertionPointAfter(op);
  auto symRef = mlir::SymbolRefAttr::get(rewriter.getContext(), fqsym);

  // `replaceOpWithNewOp` already erases `op`. Erasing it a second time
  // would be a double replacement in the ConversionPatternRewriter.
  // NOTE(review): reconstructed op name; confirm it is slir::SetValueOp.
  rewriter.replaceOpWithNewOp<serene::slir::SetValueOp>(op, symRef, value);
  return mlir::success();
}

/// Lowers `slir::DefineConstantOp` to a constant, linkonce `llvm.mlir.global`
/// named after the fully-qualified symbol and initialised with the op's
/// attribute value.
struct LowerDefineConstant
    : public mlir::OpConversionPattern<serene::slir::DefineConstantOp> {
  using OpConversionPattern::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(serene::slir::DefineConstantOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override;
};

mlir::LogicalResult LowerDefineConstant::matchAndRewrite(
    serene::slir::DefineConstantOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  (void)adaptor;

  auto value = op.value();
  auto loc = op.getLoc();
  auto moduleOp = op->getParentOfType<mlir::ModuleOp>();

  std::string fqsym;
  makeFQSymbolName(moduleOp.getNameAttr().getValue(), op.getName(), fqsym);

  // TODO: [lib] If we're building an executable `linkonce` is a good choice
  // but for a library we need to choose a better link type
  ll::Linkage linkage = ll::Linkage::Linkonce;

  {
    // `llvm.mlir.global` has to live at module level, and a non top-level
    // define_constant may be nested in another op's region. So the global is
    // created at the start of the module body, as LowerDefine does, rather
    // than at the op's position.
    mlir::OpBuilder::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPointToStart(moduleOp.getBody());

    // TODO: use ll::ConstantOp instead
    UNUSED(rewriter.create<ll::GlobalOp>(loc, value.getType(),
                                         /*isConstant=*/true, linkage, fqsym,
                                         value));
  }

  rewriter.eraseOp(op);
  return mlir::success();
}

#define GEN_PASS_CLASSES
#include "serene/passes/passes.h.inc"

/// Partial-conversion pass that lowers SLIR definitions and symbols to the
/// LLVM dialect using the patterns above.
class LowerSLIR : public LowerSLIRBase<LowerSLIR> {
  void runOnOperation() override {
    mlir::ModuleOp module = getOperation();

    // The first thing to define is the conversion target. This will define
    // the final target for this lowering.
    mlir::ConversionTarget target(getContext());

    slir::TypeConverter typeConverter(getContext());

    // Dialects that are legal targets for this lowering.
    // NOTE(review): reconstructed dialect list; confirm.
    target.addLegalDialect<mlir::func::FuncDialect>();
    target.addLegalDialect<mlir::arith::ArithmeticDialect>();
    target.addLegalDialect<ll::LLVMDialect>();

    // We also define the SLIR dialect as Illegal so that the conversion will
    // fail if any of these operations are *not* converted.
    target.addIllegalDialect<serene::slir::SereneDialect>();

    // Mark operations that are LEGAL for this pass. It means that we don't
    // lower them in this pass but we will in another pass. So we don't want
    // to get an error since we are not lowering them.
    // NOTE(review): reconstructed list. FnOp is left in place by LowerDefine
    // and SetValueOp is what LowerDefine emits; confirm.
    target.addLegalOp<serene::slir::FnOp, serene::slir::SetValueOp>();

    // Now that the conversion target has been defined, we just need to
    // provide the set of patterns that will lower the SLIR operations.
    mlir::RewritePatternSet patterns(&getContext());
    patterns.add<LowerDefine, LowerDefineConstant, LowerSymbol>(
        typeConverter, &getContext());

    // With the target and rewrite patterns defined, we can now attempt the
    // conversion. The conversion will signal failure if any of our `illegal`
    // operations were not converted successfully.
    if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

std::unique_ptr<mlir::Pass> createLowerSLIR() {
  return std::make_unique<LowerSLIR>();
}

#define GEN_PASS_REGISTRATION
#include "serene/passes/passes.h.inc"

// ----------------------------------------------------------------------------
// ValueOp lowering to constant op

/// Replaces a `slir::Value1Op` with a private `func.func` named "randomname"
/// that takes no arguments and returns the op's value as an i64 constant.
struct ValueOpLowering : public mlir::OpRewritePattern<serene::slir::Value1Op> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(serene::slir::Value1Op op,
                                      mlir::PatternRewriter &rewriter) const final;
};

mlir::LogicalResult
ValueOpLowering::matchAndRewrite(serene::slir::Value1Op op,
                                 mlir::PatternRewriter &rewriter) const {
  auto value = op.value();
  mlir::Location loc = op.getLoc();

  llvm::SmallVector<mlir::Type, 1> argTypes;
  auto funcType = rewriter.getFunctionType(argTypes, rewriter.getI64Type());

  // TODO: use a mechanism to generate unique names
  auto fn = rewriter.create<mlir::func::FuncOp>(loc, "randomname", funcType);
  auto *entryBlock = fn.addEntryBlock();

  rewriter.setInsertionPointToStart(entryBlock);

  // Since we only support i64 at the moment we use ConstantOp
  auto retVal = rewriter
                    .create<mlir::arith::ConstantIntOp>(
                        loc, static_cast<int64_t>(value), rewriter.getI64Type())
                    .getResult();

  UNUSED(rewriter.create<mlir::func::ReturnOp>(loc, retVal));

  fn.setPrivate();

  // Erase the original ValueOP
  rewriter.eraseOp(op);
  return mlir::success();
}

// ----------------------------------------------------------------------------
// Fn lowering pattern

/// Replaces a `slir::Fn1Op` with a `func.func` of the same name whose
/// argument types come from the op's `args` (each must be a TypeAttr) and
/// which returns i64. The function is made private unless the op's
/// visibility is "public" (the default when unset).
// NOTE(review): the generated body ignores the original function body and
// always returns the constant 3.
struct FnOpLowering : public mlir::OpRewritePattern<serene::slir::Fn1Op> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(serene::slir::Fn1Op op,
                                      mlir::PatternRewriter &rewriter) const final;
};

mlir::LogicalResult
FnOpLowering::matchAndRewrite(serene::slir::Fn1Op op,
                              mlir::PatternRewriter &rewriter) const {
  auto args = op.args();
  auto name = op.name();
  auto isPublic = op.sym_visibility().getValueOr("public") == "public";

  mlir::Location loc = op.getLoc();

  llvm::SmallVector<mlir::Type, 4> argTypes;
  for (const auto &arg : args) {
    auto attr = arg.getValue().dyn_cast<mlir::TypeAttr>();
    if (!attr) {
      op.emitError("It's not a type attr");
      return mlir::failure();
    }
    argTypes.push_back(attr.getValue());
  }

  auto funcType = rewriter.getFunctionType(argTypes, rewriter.getI64Type());
  auto fn = rewriter.create<mlir::func::FuncOp>(loc, name, funcType);

  auto *entryBlock = fn.addEntryBlock();
  rewriter.setInsertionPointToStart(entryBlock);

  auto retVal = rewriter
                    .create<mlir::arith::ConstantIntOp>(
                        loc, static_cast<int64_t>(3), rewriter.getI64Type())
                    .getResult();

  rewriter.create<mlir::func::ReturnOp>(loc, retVal);

  if (!isPublic) {
    fn.setPrivate();
  }

  rewriter.eraseOp(op);
  return mlir::success();
}

// ----------------------------------------------------------------------------
// SLIR lowering pass
// This pass lowers `Value1Op`/`Fn1Op` SLIR operations to the func and arith
// dialects.
struct SLIRToMLIRPass
    : public mlir::PassWrapper<SLIRToMLIRPass,
                               mlir::OperationPass<mlir::ModuleOp>> {
  void getDependentDialects(mlir::DialectRegistry &registry) const override;
  void runOnOperation() final;
  void runOnModule();
  mlir::ModuleOp getModule();
};
It's basically translate to what // dialects do we want to lower to void SLIRToMLIRPass::getDependentDialects( mlir::DialectRegistry ®istry) const { registry.insert(); }; /// Return the current function being transformed. mlir::ModuleOp SLIRToMLIRPass::getModule() { return this->getOperation(); } void SLIRToMLIRPass::runOnOperation() { runOnModule(); } void SLIRToMLIRPass::runOnModule() { auto module = getModule(); // The first thing to define is the conversion target. This will define the // final target for this lowering. mlir::ConversionTarget target(getContext()); // We define the specific operations, or dialects, that are legal targets for // this lowering. In our case, we are lowering to the `Standard` dialects. target.addLegalDialect(); target.addLegalDialect(); // We also define the SLIR dialect as Illegal so that the conversion will fail // if any of these operations are *not* converted. target.addIllegalDialect(); // Mark operations that are LEGAL for this pass. It means that we don't lower // them is this pass but we will in another pass. So we don't want to get // an error since we are not lowering them. // target.addLegalOp(); target.addLegalOp(); // Now that the conversion target has been defined, we just need to provide // the set of patterns that will lower the SLIR operations. mlir::RewritePatternSet patterns(&getContext()); // Pattern to lower ValueOp and FnOp patterns.add(&getContext()); // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` // operations were not converted successfully. if (failed(applyPartialConversion(module, target, std::move(patterns)))) { signalPassFailure(); } }; std::unique_ptr createSLIRLowerToMLIRPass() { return std::make_unique(); }; void registerAllPasses() { registerPasses(); } } // namespace serene::passes