Skip to content

[Fix] Enhance floormod simplification rules for better expression matching #17765

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

Ghosts381937
Copy link

Description

Update the floormod simplification rule to correctly handle expressions of the form floormod(c1*x, c2*x) by simplifying them to floormod(c1, c2). This enhancement enables better optimization of expressions that contain common factors, which frequently appear in transformer model computations.

Test Case

from tvm import tir
from tvm.arith import Analyzer
from tvm.tir.op import floormod

# Define symbolic variable
past_decoder_sequence_length = tir.Var("past_decoder_sequence_length", "int64")

# Create expressions with common factor
expr = tir.IntImm("int64", 64) * (past_decoder_sequence_length + tir.IntImm("int64", 1))
divisor = tir.IntImm("int64", 31) * (past_decoder_sequence_length + tir.IntImm("int64", 1))

# Create Analyzer
analyzer = Analyzer()

# Before: returns unsimplified expression
# After: correctly simplifies to 2
print(analyzer.simplify(floormod(expr, divisor)))

@@ -1230,7 +1230,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x + y * floormod(c1, c2), c2),
c2.Eval()->value > 0);

TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0);
TVM_TRY_REWRITE_IF(matches_one_of(floormod(x * c1, x * c2), floormod(c1 * x, c2 * x)),
floormod(c1, c2), c2.Eval()->value != 0);
Copy link
Member

@tqchen tqchen Mar 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit surprised to see if we need c * x case since canonical simplify will move most cases to x * c, would be good to understand why original flow fails in this case

@tqchen
Copy link
Member

tqchen commented Mar 19, 2025

Thanks for the PR, would be good to learn about your particular usecase(e.g. how did you find out about the case).

How did you construct the expression/model, what if the code is changed explicitly to x * c pattern?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants