From a9b3d6bfd710e87529127aac3315c31d94bdddaa Mon Sep 17 00:00:00 2001
From: Adin Schmahmann
Date: Tue, 17 Jan 2023 01:15:43 -0500
Subject: [PATCH] feat(experiment): add a parallel HAMT traversal function

---
 hamt.go | 203 +++++++++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 194 insertions(+), 9 deletions(-)

diff --git a/hamt.go b/hamt.go
index e5891ff..1c0d291 100644
--- a/hamt.go
+++ b/hamt.go
@@ -4,8 +4,10 @@ import (
 	"bytes"
 	"context"
 	"fmt"
+	"golang.org/x/sync/errgroup"
 	"math/big"
 	"sort"
+	"sync"
 
 	cid "github.com/ipfs/go-cid"
 	cbor "github.com/ipfs/go-ipld-cbor"
@@ -372,16 +374,27 @@ func loadNode(
 	out.bitWidth = bitWidth
 	out.hash = hashFunction
 
+	if err := validateNode(out, isRoot); err != nil {
+		return nil, err
+	}
+	return &out, nil
+}
+
+// validateNode validates a node
+func validateNode(
+	out Node,
+	isRoot bool,
+) error {
 	// Validation
 
 	// too many elements in the data array for the configured bitWidth?
 	if len(out.Pointers) > 1<<out.bitWidth {
-		return nil, ErrMalformedHamt
+		return ErrMalformedHamt
 	}
 
 	for _, ch := range out.Pointers {
 		if !ch.isShard() {
 			if len(ch.KVs) == 0 || len(ch.KVs) > bucketSize {
-				return nil, ErrMalformedHamt
+				return ErrMalformedHamt
 			}
 			for i := 1; i < len(ch.KVs); i++ {
 				if bytes.Compare(ch.KVs[i-1].Key, ch.KVs[i].Key) >= 0 {
-					return nil, ErrMalformedHamt
+					return ErrMalformedHamt
 				}
 			}
 		}
 	}
@@ -410,16 +423,16 @@ func loadNode(
 	if !isRoot {
 		// the only valid empty node is a root node
 		if len(out.Pointers) == 0 {
-			return nil, ErrMalformedHamt
+			return ErrMalformedHamt
 		}
 		// a non-root node that contains <=bucketSize direct elements should not
 		// exist under compaction rules
 		if out.directChildCount() == 0 && out.directKVCount() <= bucketSize {
-			return nil, ErrMalformedHamt
+			return ErrMalformedHamt
 		}
 	}
 
-	return &out, nil
+	return nil
 }
 
 // checkSize computes the total serialized size of the entire HAMT.
@@ -877,3 +890,175 @@ func (n *Node) ForEach(ctx context.Context, f func(k string, val *cbg.Deferred)
 	}
 	return nil
 }
+
+// ForEachParallel calls function f on each k / val pair found in the HAMT.
+// This performs a full traversal of the graph and for large HAMTs can cause
+// a large number of loads from the underlying store.
+// The values are returned as raw bytes, not decoded.
+// Unlike ForEach, this runs in parallel, so the passed callback must be safe to call concurrently.
+func (n *Node) ForEachParallel(ctx context.Context, f func(k string, val *cbg.Deferred) error) error {
+	return parallelShardWalk(ctx, n, f)
+}
+
+type OptionalInteger struct {
+	Value int
+	Error error
+}
+
+// TODO: This interface is obviously wrong, but we may want some GetMany-style grouping or even a "session" object
+// we can push CIDs into to leverage efficiencies without tons of goroutines
+type getManyIPLDStore interface {
+	GetMany(ctx context.Context, cids []cid.Cid, outs []interface{}) <-chan *OptionalInteger
+}
+
+type listCidsAndShards struct {
+	cids   []cid.Cid
+	shards []*Node
+}
+
+func (n *Node) walkChildren(f func(k string, val *cbg.Deferred) error) (*listCidsAndShards, error) {
+	res := &listCidsAndShards{}
+
+	for _, p := range n.Pointers {
+		if p.isShard() {
+			if p.cache != nil {
+				res.shards = append(res.shards, p.cache)
+			} else {
+				res.cids = append(res.cids, p.Link)
+			}
+		} else {
+			for _, kv := range p.KVs {
+				if err := f(string(kv.Key), kv.Value); err != nil {
+					return nil, err
+				}
+			}
+		}
+	}
+
+	return res, nil
+}
+
+// parallelShardWalk walks the HAMT concurrently, processing callbacks upon encountering leaf nodes.
+func parallelShardWalk(ctx context.Context, root *Node, processShardValues func(k string, val *cbg.Deferred) error) error {
+	const concurrency = 16 // TODO: should be an option, also this number was basically made up with a bit of empirical testing/usage
+
+	var visitlk sync.Mutex
+	visitSet := cid.NewSet()
+	visit := visitSet.Visit
+
+	// Setup synchronization
+	grp, errGrpCtx := errgroup.WithContext(ctx)
+
+	// Input and output queues for workers.
+	feed := make(chan *listCidsAndShards)
+	out := make(chan *listCidsAndShards)
+	done := make(chan struct{})
+
+	for i := 0; i < concurrency; i++ {
+		grp.Go(func() error {
+			for feedChildren := range feed {
+				for _, nextShard := range feedChildren.shards {
+					nextChildren, err := nextShard.walkChildren(processShardValues)
+					if err != nil {
+						return err
+					}
+
+					select {
+					case out <- nextChildren:
+					case <-errGrpCtx.Done():
+						return nil
+					}
+				}
+
+				var linksToVisit []cid.Cid
+				for _, nextCid := range feedChildren.cids {
+					var shouldVisit bool
+
+					visitlk.Lock()
+					shouldVisit = visit(nextCid)
+					visitlk.Unlock()
+
+					if shouldVisit {
+						linksToVisit = append(linksToVisit, nextCid)
+					}
+				}
+
+				// TODO: allow for Pointer caching
+				dserv := root.store.(getManyIPLDStore)
+				nodes := make([]interface{}, len(linksToVisit))
+				for i := 0; i < len(linksToVisit); i++ {
+					nodes[i] = new(Node)
+				}
+				chNodes := dserv.GetMany(errGrpCtx, linksToVisit, nodes)
+				for optNode := range chNodes {
+					if optNode.Error != nil {
+						return optNode.Error
+					}
+					nextShard := nodes[optNode.Value].(*Node)
+					nextShard.store = root.store
+					nextShard.bitWidth = root.bitWidth
+					nextShard.hash = root.hash
+					if err := validateNode(*nextShard, false); err != nil {
+						return err
+					}
+
+					nextChildren, err := nextShard.walkChildren(processShardValues)
+					if err != nil {
+						return err
+					}
+
+					select {
+					case out <- nextChildren:
+					case <-errGrpCtx.Done():
+						return nil
+					}
+				}
+
+				select {
+				case done <- struct{}{}:
+				case <-errGrpCtx.Done():
+				}
+			}
+			return nil
+		})
+	}
+
+	send := feed
+	var todoQueue []*listCidsAndShards
+	var inProgress int
+
+	next := &listCidsAndShards{
+		shards: []*Node{root},
+	}
+
+dispatcherLoop:
+	for {
+		select {
+		case send <- next:
+			inProgress++
+			if len(todoQueue) > 0 {
+				next = todoQueue[0]
+				todoQueue = todoQueue[1:]
+			} else {
+				next = nil
+				send = nil
+			}
+		case <-done:
+			inProgress--
+			if inProgress == 0 && next == nil {
+				break dispatcherLoop
+			}
+		case nextNodes := <-out:
+			if next == nil {
+				next = nextNodes
+				send = feed
+			} else {
+				todoQueue = append(todoQueue, nextNodes)
+			}
+		case <-errGrpCtx.Done():
+			break dispatcherLoop
+		}
+	}
+	close(feed)
+	return grp.Wait()
+}
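
A rough in-package usage sketch (not something this patch adds): because the workers invoke the callback concurrently, any shared state it touches needs its own synchronization, and the node's store has to satisfy the getManyIPLDStore assertion in parallelShardWalk or the type assertion will panic. The helper name countEntriesParallel is made up for illustration.

func countEntriesParallel(ctx context.Context, root *Node) (int, error) {
	var mu sync.Mutex
	count := 0
	err := root.ForEachParallel(ctx, func(k string, val *cbg.Deferred) error {
		// The callback runs concurrently across workers, so guard the shared counter.
		mu.Lock()
		count++
		mu.Unlock()
		return nil
	})
	return count, err
}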
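
Similarly, a minimal sketch of a store that satisfies getManyIPLDStore on top of a plain cbor.IpldStore, following the contract the worker loop relies on: decode each requested CID into the matching slot of outs, send the slot index (or an error) as an *OptionalInteger, and close the channel when finished. The sequentialGetMany wrapper is hypothetical and not part of the patch; a real implementation would presumably batch the fetches rather than loop over Get.

type sequentialGetMany struct {
	cbor.IpldStore
}

func (s sequentialGetMany) GetMany(ctx context.Context, cids []cid.Cid, outs []interface{}) <-chan *OptionalInteger {
	ch := make(chan *OptionalInteger)
	go func() {
		defer close(ch)
		for i, c := range cids {
			// Decode into the caller-provided slot; report the index on success.
			res := &OptionalInteger{Value: i}
			if err := s.Get(ctx, c, outs[i]); err != nil {
				res = &OptionalInteger{Error: err}
			}
			select {
			case ch <- res:
			case <-ctx.Done():
				return
			}
		}
	}()
	return ch
}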