-
(This code snippet came from @hawkinsp.)
-
Let's wait to add more HOWTOs until we switch the HOWTO system over to the Linen examples.
-
I gather this probably isn't intended to be part of the public API, but I noticed on CPU the estimated flop count for
-
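JAX now has this:

import jax

@jax.jit
def f(x, y):
    return 2 * x + y

# Query the compiled computation's cost analysis and print the FLOP estimate.
# Note: depending on the JAX version, cost_analysis() may return a single dict
# instead of a list of dicts; in that case drop the [0].
x, y = 3, 4
compiled = f.lower(x, y).compile()
flops = compiled.cost_analysis()[0]['flops']
print(flops)  # 2.0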
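In case it helps, here is a minimal sketch that wraps the same `lower(...).compile().cost_analysis()` calls in a helper. The names `estimate_flops` and `matmul` are hypothetical (not part of JAX), and the list-versus-dict handling is an assumption: as far as I know, some JAX versions return a list of dicts from `cost_analysis()` while others return a single dict, and some backends may not report a `flops` entry at all.

```python
import jax
import jax.numpy as jnp


def estimate_flops(fn, *args):
    """Hypothetical helper: return XLA's FLOP estimate for fn(*args), or None."""
    compiled = jax.jit(fn).lower(*args).compile()
    cost = compiled.cost_analysis()
    # Assumption: the result is either a list of dicts or a single dict,
    # depending on the JAX version. Normalise to a single dict.
    if isinstance(cost, (list, tuple)):
        cost = cost[0] if cost else None
    if not cost:
        return None
    # The 'flops' key may be missing on some backends.
    return cost.get("flops")


def matmul(a, b):
    return a @ b


a = jnp.ones((128, 256))
b = jnp.ones((256, 64))
flops = estimate_flops(matmul, a, b)
print(flops)
```

The number is the compiler's estimate for the lowered program on the current backend, so it can differ between CPU, GPU, and TPU and should not be read as an exact operation count.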
-
For now, here is how you do it:
(This may become a JAX API later)