@@ -815,11 +815,19 @@ static LogicalResult setRootConfigForSoftmaxCopyPipeline(
815
815
mlir::FunctionOpInterface entryPointFn, linalg::SoftmaxOp softmaxOp,
816
816
AMDAIEDevice targetDevice, uint32_t numRows, uint32_t numCols,
817
817
std::string enableAMDAIEUkernels) {
818
- // For now, we are targeting a single column of cores, and the tile sizes are
819
- // hardcoded. We don't tile the reduction dim as the softmax op is not a pure
820
- // reduction op.
818
+ // For now, we are targeting a single column of cores, and the L1 tile sizes
819
+ // are hardcoded. We don't tile the reduction dim as the softmax op is not a
820
+ // pure reduction op.
821
+ ArrayRef<int64_t > inputShape = softmaxOp.getInput ().getType ().getShape ();
822
+ int64_t m1Tile = std::min<int64_t >(inputShape[0 ], 32 );
823
+ int64_t m0Tile = std::min<int64_t >(inputShape[0 ], numRows * m1Tile);
824
+
825
+ SmallVector<int64_t > tileSizeLevel0 = {m0Tile, 0 };
826
+ SmallVector<int64_t > tileSizeLevel1 = {m1Tile, 0 };
827
+ SmallVector<int64_t > tileSizeLevel2 = {0 , 0 };
821
828
if (failed (setOpConfigAndEntryPointFnTranslation (
822
- entryPointFn, softmaxOp, TileSizesListType{{128 , 0 }, {32 , 0 }, {0 , 0 }},
829
+ entryPointFn, softmaxOp,
830
+ TileSizesListType{tileSizeLevel0, tileSizeLevel1, tileSizeLevel2},
823
831
IREE::Codegen::DispatchLoweringPassPipeline::Custom))) {
824
832
return failure ();
825
833
}
0 commit comments