Skip to content

Commit fdca87a

Browse files
committed
fix: explicitly signal "delete" operation in modifyValue
1 parent 48b9a98 commit fdca87a

File tree

5 files changed

+46
-51
lines changed

5 files changed

+46
-51
lines changed

diff.go

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -81,11 +81,11 @@ func diffNode[T HamtValue[T]](ctx context.Context, pre, cur *Node[T], depth int)
8181
if prePointer.Link == curPointer.Link {
8282
continue
8383
}
84-
preChild, err := prePointer.loadChild(ctx, pre.store, pre.bitWidth, pre.hash, pre.zeroValue)
84+
preChild, err := prePointer.loadChild(ctx, pre.store, pre.bitWidth, pre.hash)
8585
if err != nil {
8686
return nil, err
8787
}
88-
curChild, err := curPointer.loadChild(ctx, cur.store, cur.bitWidth, cur.hash, pre.zeroValue)
88+
curChild, err := curPointer.loadChild(ctx, cur.store, cur.bitWidth, cur.hash)
8989
if err != nil {
9090
return nil, err
9191
}
@@ -99,7 +99,7 @@ func diffNode[T HamtValue[T]](ctx context.Context, pre, cur *Node[T], depth int)
9999

100100
// check if KV's from cur exists in any children of pre's child.
101101
if prePointer.isShard() && !curPointer.isShard() {
102-
childKV, err := prePointer.loadChildKVs(ctx, pre.store, pre.bitWidth, pre.hash, pre.zeroValue)
102+
childKV, err := prePointer.loadChildKVs(ctx, pre.store, pre.bitWidth, pre.hash)
103103
if err != nil {
104104
return nil, err
105105
}
@@ -109,7 +109,7 @@ func diffNode[T HamtValue[T]](ctx context.Context, pre, cur *Node[T], depth int)
109109

110110
// check if KV's from pre exists in any children of cur's child.
111111
if !prePointer.isShard() && curPointer.isShard() {
112-
childKV, err := curPointer.loadChildKVs(ctx, cur.store, cur.bitWidth, cur.hash, pre.zeroValue)
112+
childKV, err := curPointer.loadChildKVs(ctx, cur.store, cur.bitWidth, cur.hash)
113113
if err != nil {
114114
return nil, err
115115
}
@@ -125,7 +125,7 @@ func diffNode[T HamtValue[T]](ctx context.Context, pre, cur *Node[T], depth int)
125125
pointer := pre.getPointer(byte(pre.indexForBitPos(idx)))
126126

127127
if pointer.isShard() {
128-
child, err := pointer.loadChild(ctx, pre.store, pre.bitWidth, pre.hash, pre.zeroValue)
128+
child, err := pointer.loadChild(ctx, pre.store, pre.bitWidth, pre.hash)
129129
if err != nil {
130130
return nil, err
131131
}
@@ -149,7 +149,7 @@ func diffNode[T HamtValue[T]](ctx context.Context, pre, cur *Node[T], depth int)
149149
pointer := cur.getPointer(byte(cur.indexForBitPos(idx)))
150150

151151
if pointer.isShard() {
152-
child, err := pointer.loadChild(ctx, pre.store, pre.bitWidth, pre.hash, pre.zeroValue)
152+
child, err := pointer.loadChild(ctx, pre.store, pre.bitWidth, pre.hash)
153153
if err != nil {
154154
return nil, err
155155
}
@@ -174,7 +174,7 @@ func diffNode[T HamtValue[T]](ctx context.Context, pre, cur *Node[T], depth int)
174174
return changes, nil
175175
}
176176

177-
func diffKVs[T HamtValue[T]](pre, cur []*KV[T], idx int) []*Change[T] {
177+
func diffKVs[T HamtValue[T]](pre, cur []*KV[T], _ int) []*Change[T] {
178178
preMap := make(map[string]T, len(pre))
179179
curMap := make(map[string]T, len(cur))
180180
var changes []*Change[T]
@@ -220,7 +220,7 @@ func diffKVs[T HamtValue[T]](pre, cur []*KV[T], idx int) []*Change[T] {
220220
return changes
221221
}
222222

223-
func addAll[T HamtValue[T]](ctx context.Context, node *Node[T], idx int) ([]*Change[T], error) {
223+
func addAll[T HamtValue[T]](ctx context.Context, node *Node[T], _ int) ([]*Change[T], error) {
224224
var changes []*Change[T]
225225
if err := node.ForEach(ctx, func(k string, val T) error {
226226
changes = append(changes, &Change[T]{
@@ -237,7 +237,7 @@ func addAll[T HamtValue[T]](ctx context.Context, node *Node[T], idx int) ([]*Cha
237237
return changes, nil
238238
}
239239

240-
func removeAll[T HamtValue[T]](ctx context.Context, node *Node[T], idx int) ([]*Change[T], error) {
240+
func removeAll[T HamtValue[T]](ctx context.Context, node *Node[T], _ int) ([]*Change[T], error) {
241241
var changes []*Change[T]
242242
if err := node.ForEach(ctx, func(k string, val T) error {
243243
changes = append(changes, &Change[T]{

diff_parallel.go

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -193,11 +193,11 @@ func (s *diffScheduler[T]) work(ctx context.Context, todo *task[T], results chan
193193
if prePointer.Link == curPointer.Link {
194194
return nil
195195
}
196-
preChild, err := prePointer.loadChild(ctx, pre.store, pre.bitWidth, pre.hash, pre.zeroValue)
196+
preChild, err := prePointer.loadChild(ctx, pre.store, pre.bitWidth, pre.hash)
197197
if err != nil {
198198
return err
199199
}
200-
curChild, err := curPointer.loadChild(ctx, cur.store, cur.bitWidth, cur.hash, pre.zeroValue)
200+
curChild, err := curPointer.loadChild(ctx, cur.store, cur.bitWidth, cur.hash)
201201
if err != nil {
202202
return err
203203
}
@@ -220,15 +220,15 @@ func (s *diffScheduler[T]) work(ctx context.Context, todo *task[T], results chan
220220

221221
// check if KV's from cur exists in any children of pre's child.
222222
case prePointer.isShard() && !curPointer.isShard():
223-
childKV, err := prePointer.loadChildKVs(ctx, pre.store, pre.bitWidth, pre.hash, pre.zeroValue)
223+
childKV, err := prePointer.loadChildKVs(ctx, pre.store, pre.bitWidth, pre.hash)
224224
if err != nil {
225225
return err
226226
}
227227
parallelDiffKVs(childKV, curPointer.KVs, results)
228228

229229
// check if KV's from pre exists in any children of cur's child.
230230
case !prePointer.isShard() && curPointer.isShard():
231-
childKV, err := curPointer.loadChildKVs(ctx, cur.store, cur.bitWidth, cur.hash, pre.zeroValue)
231+
childKV, err := curPointer.loadChildKVs(ctx, cur.store, cur.bitWidth, cur.hash)
232232
if err != nil {
233233
return err
234234
}
@@ -243,7 +243,7 @@ func (s *diffScheduler[T]) work(ctx context.Context, todo *task[T], results chan
243243
pointer := pre.getPointer(byte(pre.indexForBitPos(idx)))
244244

245245
if pointer.isShard() {
246-
child, err := pointer.loadChild(ctx, pre.store, pre.bitWidth, pre.hash, pre.zeroValue)
246+
child, err := pointer.loadChild(ctx, pre.store, pre.bitWidth, pre.hash)
247247
if err != nil {
248248
return err
249249
}
@@ -266,7 +266,7 @@ func (s *diffScheduler[T]) work(ctx context.Context, todo *task[T], results chan
266266
pointer := cur.getPointer(byte(cur.indexForBitPos(idx)))
267267

268268
if pointer.isShard() {
269-
child, err := pointer.loadChild(ctx, pre.store, pre.bitWidth, pre.hash, pre.zeroValue)
269+
child, err := pointer.loadChild(ctx, pre.store, pre.bitWidth, pre.hash)
270270
if err != nil {
271271
return err
272272
}

diff_test.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -340,7 +340,7 @@ func TestBigDiff(t *testing.T) {
340340
}
341341
}
342342

343-
func diffAndAssertLength(ctx context.Context, t *testing.T, prevBs, curBs cbor.IpldStore, a, b *Node[*CborByteArray], expectedLength int) []*Change[*CborByteArray] {
343+
func diffAndAssertLength(ctx context.Context, t *testing.T, _, _ cbor.IpldStore, a, b *Node[*CborByteArray], expectedLength int) []*Change[*CborByteArray] {
344344
if err := a.Flush(ctx); err != nil {
345345
t.Fatal(err)
346346
}

hamt.go

Lines changed: 28 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -106,8 +106,6 @@ type Node[T HamtValue[T]] struct {
106106

107107
// for fetching and storing children
108108
store cbor.IpldStore
109-
110-
zeroValue T
111109
}
112110

113111
// Pointer is an element in a HAMT node's Pointers array, encoded as an IPLD
@@ -200,7 +198,7 @@ func NewNode[T HamtValue[T]](cs cbor.IpldStore, options ...Option) (*Node[T], er
200198
}
201199
}
202200

203-
return newNode[T](cs, cfg.hashFn, cfg.bitWidth, zero[T]()), nil
201+
return newNode[T](cs, cfg.hashFn, cfg.bitWidth), nil
204202
}
205203

206204
// Find navigates through the HAMT structure to where key `k` should exist. If
@@ -234,19 +232,18 @@ func (n *Node[T]) Find(ctx context.Context, k string) (T, bool, error) {
234232
// further nodes.
235233
func (n *Node[T]) Delete(ctx context.Context, k string) (bool, error) {
236234
kb := []byte(k)
237-
modified, err := n.modifyValue(ctx, &hashBits{b: n.hash(kb)}, kb, n.zeroValue, OVERWRITE)
235+
modified, err := n.modifyValue(ctx, &hashBits{b: n.hash(kb)}, kb, zero[T](), true, OVERWRITE)
238236
return modified == MODIFIED, err
239237
}
240238

241239
// Constructs a new node value.
242-
func newNode[T HamtValue[T]](cs cbor.IpldStore, hashFn HashFunc, bitWidth int, zeroValue T) *Node[T] {
240+
func newNode[T HamtValue[T]](cs cbor.IpldStore, hashFn HashFunc, bitWidth int) *Node[T] {
243241
nd := &Node[T]{
244-
Bitfield: big.NewInt(0),
245-
Pointers: make([]*Pointer[T], 0),
246-
bitWidth: bitWidth,
247-
hash: hashFn,
248-
store: cs,
249-
zeroValue: zeroValue,
242+
Bitfield: big.NewInt(0),
243+
Pointers: make([]*Pointer[T], 0),
244+
bitWidth: bitWidth,
245+
hash: hashFn,
246+
store: cs,
250247
}
251248
return nd
252249
}
@@ -281,7 +278,7 @@ func (n *Node[T]) getValue(ctx context.Context, hv *hashBits, k string, cb func(
281278
if c.isShard() {
282279
// if isShard, we have a pointer to a child that we need to load and
283280
// delegate our find operation to
284-
chnd, err := c.loadChild(ctx, n.store, n.bitWidth, n.hash, n.zeroValue)
281+
chnd, err := c.loadChild(ctx, n.store, n.bitWidth, n.hash)
285282
if err != nil {
286283
return err
287284
}
@@ -303,12 +300,12 @@ func (n *Node[T]) getValue(ctx context.Context, hv *hashBits, k string, cb func(
303300

304301
// load a HAMT node from the IpldStore and pass on the (assumed) parameters
305302
// that are not stored with the node.
306-
func (p *Pointer[T]) loadChild(ctx context.Context, ns cbor.IpldStore, bitWidth int, hash HashFunc, zeroValue T) (*Node[T], error) {
303+
func (p *Pointer[T]) loadChild(ctx context.Context, ns cbor.IpldStore, bitWidth int, hash HashFunc) (*Node[T], error) {
307304
if p.cache != nil {
308305
return p.cache, nil
309306
}
310307

311-
out, err := loadNode[T](ctx, ns, p.Link, false, bitWidth, hash, zeroValue)
308+
out, err := loadNode[T](ctx, ns, p.Link, false, bitWidth, hash)
312309
if err != nil {
313310
return nil, err
314311
}
@@ -319,8 +316,8 @@ func (p *Pointer[T]) loadChild(ctx context.Context, ns cbor.IpldStore, bitWidth
319316

320317
// load a HAMT node from the IpldStore passing on the (assumed) parameters
321318
// that are not stored with the node and return all KVs of the child and its children.
322-
func (p *Pointer[T]) loadChildKVs(ctx context.Context, ns cbor.IpldStore, bitWidth int, hash HashFunc, zeroValue T) ([]*KV[T], error) {
323-
child, err := p.loadChild(ctx, ns, bitWidth, hash, zeroValue)
319+
func (p *Pointer[T]) loadChildKVs(ctx context.Context, ns cbor.IpldStore, bitWidth int, hash HashFunc) ([]*KV[T], error) {
320+
child, err := p.loadChild(ctx, ns, bitWidth, hash)
324321
if err != nil {
325322
return nil, err
326323
}
@@ -355,7 +352,7 @@ func LoadNode[T HamtValue[T]](ctx context.Context, cs cbor.IpldStore, c cid.Cid,
355352
return nil, err
356353
}
357354
}
358-
return loadNode[T](ctx, cs, c, true, cfg.bitWidth, cfg.hashFn, zero[T]())
355+
return loadNode[T](ctx, cs, c, true, cfg.bitWidth, cfg.hashFn)
359356
}
360357

361358
// internal version of loadNode that is aware of whether this is a root node or
@@ -367,7 +364,6 @@ func loadNode[T HamtValue[T]](
367364
isRoot bool,
368365
bitWidth int,
369366
hashFunction HashFunc,
370-
zeroValue T,
371367
) (*Node[T], error) {
372368
var out Node[T]
373369
if err := cs.Get(ctx, c, &out); err != nil {
@@ -377,7 +373,6 @@ func loadNode[T HamtValue[T]](
377373
out.store = cs
378374
out.bitWidth = bitWidth
379375
out.hash = hashFunction
380-
out.zeroValue = zeroValue
381376

382377
// Validation
383378

@@ -457,7 +452,7 @@ func (n *Node[T]) checkSize(ctx context.Context) (uint64, error) {
457452
totsize := uint64(buf.Len())
458453
for _, ch := range n.Pointers {
459454
if ch.isShard() {
460-
chnd, err := ch.loadChild(ctx, n.store, n.bitWidth, n.hash, n.zeroValue)
455+
chnd, err := ch.loadChild(ctx, n.store, n.bitWidth, n.hash)
461456
if err != nil {
462457
return 0, err
463458
}
@@ -533,7 +528,7 @@ func (n *Node[T]) Flush(ctx context.Context) error {
533528
// and save the resulting CID wherever you expect the HAMT root to persist.
534529
func (n *Node[T]) Set(ctx context.Context, k string, v T) error {
535530
keyBytes := []byte(k)
536-
_, err := n.modifyValue(ctx, &hashBits{b: n.hash(keyBytes)}, keyBytes, v, OVERWRITE)
531+
_, err := n.modifyValue(ctx, &hashBits{b: n.hash(keyBytes)}, keyBytes, v, false, OVERWRITE)
537532
return err
538533
}
539534

@@ -542,7 +537,7 @@ func (n *Node[T]) Set(ctx context.Context, k string, v T) error {
542537
// false otherwise.
543538
func (n *Node[T]) SetIfAbsent(ctx context.Context, k string, v T) (bool, error) {
544539
keyBytes := []byte(k)
545-
modified, err := n.modifyValue(ctx, &hashBits{b: n.hash(keyBytes)}, keyBytes, v, NOVERWRITE)
540+
modified, err := n.modifyValue(ctx, &hashBits{b: n.hash(keyBytes)}, keyBytes, v, false, NOVERWRITE)
546541
return bool(modified), err
547542
}
548543

@@ -626,7 +621,7 @@ func (n *Node[T]) cleanChild(chnd *Node[T], cindex byte) error {
626621
// cleanNode()). Recursive calls use the same arguments on child nodes but
627622
// note that `hv.Next()` is not idempotent. Each call will increment the number
628623
// of bits chomped off the hash digest for this key.
629-
func (n *Node[T]) modifyValue(ctx context.Context, hv *hashBits, k []byte, v T, replace overwrite) (modified, error) {
624+
func (n *Node[T]) modifyValue(ctx context.Context, hv *hashBits, k []byte, v T, delete bool, replace overwrite) (modified, error) {
630625
idx, err := hv.Next(n.bitWidth)
631626
if err != nil {
632627
return UNMODIFIED, ErrMaxDepth
@@ -636,7 +631,7 @@ func (n *Node[T]) modifyValue(ctx context.Context, hv *hashBits, k []byte, v T,
636631
// doesn't exist in the HAMT already and can insert it at the appropriate
637632
// position.
638633
if n.Bitfield.Bit(idx) != 1 {
639-
if n.zeroValue.Equals(v) { // Delete absent key
634+
if delete { // Delete absent key
640635
return UNMODIFIED, nil
641636
}
642637
return MODIFIED, n.insertKV(idx, k, v)
@@ -656,12 +651,12 @@ func (n *Node[T]) modifyValue(ctx context.Context, hv *hashBits, k []byte, v T,
656651
// it is an eventual Flush passing back over this "cache" node which
657652
// causes the updates made to the in-memory "cache" node to eventually
658653
// be persisted.
659-
chnd, err := child.loadChild(ctx, n.store, n.bitWidth, n.hash, n.zeroValue)
654+
chnd, err := child.loadChild(ctx, n.store, n.bitWidth, n.hash)
660655
if err != nil {
661656
return UNMODIFIED, err
662657
}
663658

664-
modified, err := chnd.modifyValue(ctx, hv, k, v, replace)
659+
modified, err := chnd.modifyValue(ctx, hv, k, v, delete, replace)
665660
if err != nil {
666661
return UNMODIFIED, err
667662
}
@@ -676,7 +671,7 @@ func (n *Node[T]) modifyValue(ctx context.Context, hv *hashBits, k []byte, v T,
676671
// current data it contains. This may involve collapsing child nodes if
677672
// they no longer contain enough elements to justify their stand-alone
678673
// existence.
679-
if n.zeroValue.Equals(v) {
674+
if delete {
680675
if err := n.cleanChild(chnd, cindex); err != nil {
681676
return UNMODIFIED, err
682677
}
@@ -689,7 +684,7 @@ func (n *Node[T]) modifyValue(ctx context.Context, hv *hashBits, k []byte, v T,
689684
// modified (or deleted) here or needs to be added as a new child node if
690685
// there is an overflow.
691686

692-
if n.zeroValue.Equals(v) {
687+
if delete {
693688
// delete operation, find the child and remove it, compacting the bucket in
694689
// the process
695690
for i, p := range child.KVs {
@@ -721,15 +716,15 @@ func (n *Node[T]) modifyValue(ctx context.Context, hv *hashBits, k []byte, v T,
721716
if len(child.KVs) >= bucketSize {
722717
// bucket is full, create a child node (shard) with all existing bucket
723718
// elements plus the new one and set it in the place of the bucket
724-
sub := newNode[T](n.store, n.hash, n.bitWidth, n.zeroValue)
719+
sub := newNode[T](n.store, n.hash, n.bitWidth)
725720
hvcopy := &hashBits{b: hv.b, consumed: hv.consumed}
726-
if _, err := sub.modifyValue(ctx, hvcopy, k, v, replace); err != nil {
721+
if _, err := sub.modifyValue(ctx, hvcopy, k, v, delete, replace); err != nil {
727722
return UNMODIFIED, err
728723
}
729724

730725
for _, p := range child.KVs {
731726
chhv := &hashBits{b: n.hash(p.Key), consumed: hv.consumed}
732-
if _, err := sub.modifyValue(ctx, chhv, p.Key, p.Value, replace); err != nil {
727+
if _, err := sub.modifyValue(ctx, chhv, p.Key, p.Value, delete, replace); err != nil {
733728
return UNMODIFIED, err
734729
}
735730
}
@@ -802,7 +797,7 @@ func (n *Node[T]) getPointer(i byte) *Pointer[T] {
802797
// as cached nodes.
803798
func (n *Node[T]) Copy() *Node[T] {
804799
// TODO(rvagg): clarify what situations this method is actually useful for.
805-
nn := newNode[T](n.store, n.hash, n.bitWidth, n.zeroValue)
800+
nn := newNode[T](n.store, n.hash, n.bitWidth)
806801
nn.Bitfield.Set(n.Bitfield)
807802
nn.Pointers = make([]*Pointer[T], len(n.Pointers))
808803

@@ -838,7 +833,7 @@ func (p *Pointer[T]) isShard() bool {
838833
func (n *Node[T]) ForEach(ctx context.Context, f func(k string, val T) error) error {
839834
for _, p := range n.Pointers {
840835
if p.isShard() {
841-
chnd, err := p.loadChild(ctx, n.store, n.bitWidth, n.hash, n.zeroValue)
836+
chnd, err := p.loadChild(ctx, n.store, n.bitWidth, n.hash)
842837
if err != nil {
843838
return err
844839
}

hamt_test.go

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -441,7 +441,7 @@ func printHamt(hamt *Node[*CborByteArray]) {
441441
fmt.Printf("%s‣ %v:\n", strings.Repeat(" ", depth), c)
442442
for _, p := range n.Pointers {
443443
if p.isShard() {
444-
child, err := p.loadChild(ctx, n.store, n.bitWidth, n.hash, n.zeroValue)
444+
child, err := p.loadChild(ctx, n.store, n.bitWidth, n.hash)
445445
if err != nil {
446446
panic(err)
447447
}
@@ -503,7 +503,7 @@ func statsrec(n *Node[*CborByteArray], st *hamtStats) {
503503
st.totalNodes++
504504
for _, p := range n.Pointers {
505505
if p.isShard() {
506-
nd, err := p.loadChild(context.Background(), n.store, n.bitWidth, n.hash, n.zeroValue)
506+
nd, err := p.loadChild(context.Background(), n.store, n.bitWidth, n.hash)
507507
if err != nil {
508508
panic(err)
509509
}

0 commit comments

Comments
 (0)