Skip to content

Commit a9ef59b

Browse files
author
Flax Authors
committed
Merge pull request #4641 from google:nnx-improve-variable
PiperOrigin-RevId: 738988788
2 parents fa0f3e8 + 6a2b33e commit a9ef59b

File tree

4 files changed

+143
-43
lines changed

4 files changed

+143
-43
lines changed

flax/nnx/variablelib.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,90 +397,148 @@ def __contains__(self, item) -> bool:
397397
return item in self.value # type: ignore
398398

399399
def __add__(self, other) -> A:
400+
if isinstance(other, Variable):
401+
other = other.value
400402
return self.value.__add__(other) # type: ignore
401403

402404
def __sub__(self, other) -> A:
405+
if isinstance(other, Variable):
406+
other = other.value
403407
return self.value.__sub__(other) # type: ignore
404408

405409
def __mul__(self, other) -> A:
410+
if isinstance(other, Variable):
411+
other = other.value
406412
return self.value.__mul__(other) # type: ignore
407413

408414
def __matmul__(self, other) -> A:
415+
if isinstance(other, Variable):
416+
other = other.value
409417
return self.value.__matmul__(other) # type: ignore
410418

411419
def __truediv__(self, other) -> A:
420+
if isinstance(other, Variable):
421+
other = other.value
412422
return self.value.__truediv__(other) # type: ignore
413423

414424
def __floordiv__(self, other) -> A:
425+
if isinstance(other, Variable):
426+
other = other.value
415427
return self.value.__floordiv__(other) # type: ignore
416428

417429
def __mod__(self, other) -> A:
430+
if isinstance(other, Variable):
431+
other = other.value
418432
return self.value.__mod__(other) # type: ignore
419433

420434
def __divmod__(self, other) -> A:
435+
if isinstance(other, Variable):
436+
other = other.value
421437
return self.value.__divmod__(other) # type: ignore
422438

423439
def __pow__(self, other) -> A:
440+
if isinstance(other, Variable):
441+
other = other.value
424442
return self.value.__pow__(other) # type: ignore
425443

426444
def __lshift__(self, other) -> A:
445+
if isinstance(other, Variable):
446+
other = other.value
427447
return self.value.__lshift__(other) # type: ignore
428448

429449
def __rshift__(self, other) -> A:
450+
if isinstance(other, Variable):
451+
other = other.value
430452
return self.value.__rshift__(other) # type: ignore
431453

432454
def __and__(self, other) -> A:
455+
if isinstance(other, Variable):
456+
other = other.value
433457
return self.value.__and__(other) # type: ignore
434458

435459
def __xor__(self, other) -> A:
460+
if isinstance(other, Variable):
461+
other = other.value
436462
return self.value.__xor__(other) # type: ignore
437463

438464
def __or__(self, other) -> A:
465+
if isinstance(other, Variable):
466+
other = other.value
439467
return self.value.__or__(other) # type: ignore
440468

441469
def __radd__(self, other) -> A:
470+
if isinstance(other, Variable):
471+
other = other.value
442472
return self.value.__radd__(other) # type: ignore
443473

444474
def __rsub__(self, other) -> A:
475+
if isinstance(other, Variable):
476+
other = other.value
445477
return self.value.__rsub__(other) # type: ignore
446478

447479
def __rmul__(self, other) -> A:
480+
if isinstance(other, Variable):
481+
other = other.value
448482
return self.value.__rmul__(other) # type: ignore
449483

450484
def __rmatmul__(self, other) -> A:
485+
if isinstance(other, Variable):
486+
other = other.value
451487
return self.value.__rmatmul__(other) # type: ignore
452488

453489
def __rtruediv__(self, other) -> A:
490+
if isinstance(other, Variable):
491+
other = other.value
454492
return self.value.__rtruediv__(other) # type: ignore
455493

456494
def __rfloordiv__(self, other) -> A:
495+
if isinstance(other, Variable):
496+
other = other.value
457497
return self.value.__rfloordiv__(other) # type: ignore
458498

459499
def __rmod__(self, other) -> A:
500+
if isinstance(other, Variable):
501+
other = other.value
460502
return self.value.__rmod__(other) # type: ignore
461503

462504
def __rdivmod__(self, other) -> A:
505+
if isinstance(other, Variable):
506+
other = other.value
463507
return self.value.__rdivmod__(other) # type: ignore
464508

465509
def __rpow__(self, other) -> A:
510+
if isinstance(other, Variable):
511+
other = other.value
466512
return self.value.__rpow__(other) # type: ignore
467513

468514
def __rlshift__(self, other) -> A:
515+
if isinstance(other, Variable):
516+
other = other.value
469517
return self.value.__rlshift__(other) # type: ignore
470518

471519
def __rrshift__(self, other) -> A:
520+
if isinstance(other, Variable):
521+
other = other.value
472522
return self.value.__rrshift__(other) # type: ignore
473523

474524
def __rand__(self, other) -> A:
525+
if isinstance(other, Variable):
526+
other = other.value
475527
return self.value.__rand__(other) # type: ignore
476528

477529
def __rxor__(self, other) -> A:
530+
if isinstance(other, Variable):
531+
other = other.value
478532
return self.value.__rxor__(other) # type: ignore
479533

480534
def __ror__(self, other) -> A:
535+
if isinstance(other, Variable):
536+
other = other.value
481537
return self.value.__ror__(other) # type: ignore
482538

483539
def __iadd__(self: V, other) -> V:
540+
if isinstance(other, Variable):
541+
other = other.value
484542
value = self.value
485543
if hasattr(value, '__iadd__'):
486544
value.__iadd__(other)
@@ -489,6 +547,8 @@ def __iadd__(self: V, other) -> V:
489547
return self
490548

491549
def __isub__(self: V, other) -> V:
550+
if isinstance(other, Variable):
551+
other = other.value
492552
value = self.value
493553
if hasattr(value, '__isub__'):
494554
value.__isub__(other)
@@ -497,6 +557,8 @@ def __isub__(self: V, other) -> V:
497557
return self
498558

499559
def __imul__(self: V, other) -> V:
560+
if isinstance(other, Variable):
561+
other = other.value
500562
value = self.value
501563
if hasattr(value, '__imul__'):
502564
value.__imul__(other)
@@ -505,6 +567,8 @@ def __imul__(self: V, other) -> V:
505567
return self
506568

507569
def __imatmul__(self: V, other) -> V:
570+
if isinstance(other, Variable):
571+
other = other.value
508572
value = self.value
509573
if hasattr(value, '__imatmul__'):
510574
value.__imatmul__(other)
@@ -513,6 +577,8 @@ def __imatmul__(self: V, other) -> V:
513577
return self
514578

515579
def __itruediv__(self: V, other) -> V:
580+
if isinstance(other, Variable):
581+
other = other.value
516582
value = self.value
517583
if hasattr(value, '__itruediv__'):
518584
value.__itruediv__(other)
@@ -521,6 +587,8 @@ def __itruediv__(self: V, other) -> V:
521587
return self
522588

523589
def __ifloordiv__(self: V, other) -> V:
590+
if isinstance(other, Variable):
591+
other = other.value
524592
value = self.value
525593
if hasattr(value, '__ifloordiv__'):
526594
value.__ifloordiv__(other)
@@ -529,6 +597,8 @@ def __ifloordiv__(self: V, other) -> V:
529597
return self
530598

531599
def __imod__(self: V, other) -> V:
600+
if isinstance(other, Variable):
601+
other = other.value
532602
value = self.value
533603
if hasattr(value, '__imod__'):
534604
value.__imod__(other)
@@ -537,6 +607,8 @@ def __imod__(self: V, other) -> V:
537607
return self
538608

539609
def __ipow__(self: V, other) -> V:
610+
if isinstance(other, Variable):
611+
other = other.value
540612
value = self.value
541613
if hasattr(value, '__ipow__'):
542614
value.__ipow__(other)
@@ -545,6 +617,8 @@ def __ipow__(self: V, other) -> V:
545617
return self
546618

547619
def __ilshift__(self: V, other) -> V:
620+
if isinstance(other, Variable):
621+
other = other.value
548622
value = self.value
549623
if hasattr(value, '__ilshift__'):
550624
value.__ilshift__(other)
@@ -553,6 +627,8 @@ def __ilshift__(self: V, other) -> V:
553627
return self
554628

555629
def __irshift__(self: V, other) -> V:
630+
if isinstance(other, Variable):
631+
other = other.value
556632
value = self.value
557633
if hasattr(value, '__irshift__'):
558634
value.__irshift__(other)
@@ -561,6 +637,8 @@ def __irshift__(self: V, other) -> V:
561637
return self
562638

563639
def __iand__(self: V, other) -> V:
640+
if isinstance(other, Variable):
641+
other = other.value
564642
value = self.value
565643
if hasattr(value, '__iand__'):
566644
value.__iand__(other)
@@ -569,6 +647,8 @@ def __iand__(self: V, other) -> V:
569647
return self
570648

571649
def __ixor__(self: V, other) -> V:
650+
if isinstance(other, Variable):
651+
other = other.value
572652
value = self.value
573653
if hasattr(value, '__ixor__'):
574654
value.__ixor__(other)
@@ -577,6 +657,8 @@ def __ixor__(self: V, other) -> V:
577657
return self
578658

579659
def __ior__(self: V, other) -> V:
660+
if isinstance(other, Variable):
661+
other = other.value
580662
value = self.value
581663
if hasattr(value, '__ior__'):
582664
value.__ior__(other)

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ filterwarnings = [
185185
"ignore:.*divide by zero encountered in.*:RuntimeWarning",
186186
# DeprecationWarning: numpy.core is deprecated
187187
"ignore:.*numpy.core is deprecated.*:DeprecationWarning",
188+
# DeprecationWarning: shape requires ndarray or scalar arguments
189+
"ignore:.*shape requires ndarray or scalar arguments.*:DeprecationWarning",
188190
]
189191

190192
[tool.coverage.report]

tests/nnx/variable_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,18 @@ def __call__(self, x):
8181

8282
self.assertEqual(result, 6)
8383

84+
def test_binary_ops(self):
85+
v1 = nnx.Param(2)
86+
v2 = nnx.Param(3)
87+
88+
result = v1 + v2
89+
90+
self.assertEqual(result, 5)
91+
92+
v1 += v2
93+
94+
self.assertEqual(v1.value, 5)
95+
8496

8597
if __name__ == '__main__':
8698
absltest.main()

0 commit comments

Comments
 (0)