
feat(experiment): add a parallel HAMT traversal function #103


Draft · wants to merge 1 commit into master
203 changes: 194 additions & 9 deletions hamt.go
@@ -4,8 +4,10 @@ import (
"bytes"
"context"
"fmt"
"golang.org/x/sync/errgroup"
"math/big"
"sort"
"sync"

cid "github.com/ipfs/go-cid"
cbor "github.com/ipfs/go-ipld-cbor"
@@ -372,16 +374,27 @@ func loadNode(
out.bitWidth = bitWidth
out.hash = hashFunction

if err := validateNode(out, isRoot); err != nil {
return nil, err
}
return &out, nil
}

// validateNode checks that a decoded node satisfies the HAMT's structural invariants
func validateNode(
out Node,
isRoot bool,
) error {
// Validation

// too many elements in the data array for the configured bitWidth?
if len(out.Pointers) > 1<<uint(out.bitWidth) {
- return nil, ErrMalformedHamt
+ return ErrMalformedHamt
}

// the bitfield is lying or the elements array is
if out.bitsSetCount() != len(out.Pointers) {
- return nil, ErrMalformedHamt
+ return ErrMalformedHamt
}

for _, ch := range out.Pointers {
@@ -390,18 +403,18 @@
if isLink == isBucket {
// Pointer#UnmarshalCBOR shouldn't allow this
// A node can only be one of link or bucket
- return nil, ErrMalformedHamt
+ return ErrMalformedHamt
}
if isLink && ch.Link.Type() != cid.DagCBOR { // not dag-cbor
- return nil, ErrMalformedHamt
+ return ErrMalformedHamt
}
if isBucket {
if len(ch.KVs) == 0 || len(ch.KVs) > bucketSize {
- return nil, ErrMalformedHamt
+ return ErrMalformedHamt
}
for i := 1; i < len(ch.KVs); i++ {
if bytes.Compare(ch.KVs[i-1].Key, ch.KVs[i].Key) >= 0 {
- return nil, ErrMalformedHamt
+ return ErrMalformedHamt
}
}
}
@@ -410,16 +423,16 @@
if !isRoot {
// the only valid empty node is a root node
if len(out.Pointers) == 0 {
- return nil, ErrMalformedHamt
+ return ErrMalformedHamt
}
// a non-root node that contains <=bucketSize direct elements should not
// exist under compaction rules
if out.directChildCount() == 0 && out.directKVCount() <= bucketSize {
- return nil, ErrMalformedHamt
+ return ErrMalformedHamt
}
}

- return &out, nil
+ return nil
}

// checkSize computes the total serialized size of the entire HAMT.
@@ -877,3 +890,175 @@ func (n *Node) ForEach(ctx context.Context, f func(k string, val *cbg.Deferred)
}
return nil
}

// ForEachParallel calls the function f on each k / val pair found in the HAMT.
// This performs a full traversal of the graph and, for large HAMTs, can cause
// a large number of loads from the underlying store.
// The values are returned as raw bytes, not decoded.
// Unlike ForEach, this traverses the HAMT in parallel, so f may be invoked
// concurrently and must be safe for concurrent use.
func (n *Node) ForEachParallel(ctx context.Context, f func(k string, val *cbg.Deferred) error) error {
return parallelShardWalk(ctx, n, f)
}
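
A minimal usage sketch (hypothetical, not part of this diff). Because f may be invoked concurrently, any shared state inside the callback must be synchronized:

	var mu sync.Mutex
	results := map[string][]byte{}
	err := n.ForEachParallel(ctx, func(k string, val *cbg.Deferred) error {
		mu.Lock()
		defer mu.Unlock()
		results[k] = val.Raw // raw, undecoded bytes
		return nil
	})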

// OptionalInteger carries either an index into the corresponding outs slice
// (Value) or the error encountered while loading that entry.
type OptionalInteger struct {
Value int
Error error
}

// TODO: This interface is obviously wrong, but we may want some GetMany-style grouping or even a "session" object
// we can push CIDs into to leverage efficiencies without tons of goroutines
type getManyIPLDStore interface {
GetMany(ctx context.Context, cids []cid.Cid, outs []interface{}) <-chan *OptionalInteger
}
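
For illustration, a minimal store satisfying this interface might look like the sketch below (hypothetical, not part of this diff). It issues sequential Get calls against an embedded cbor.IpldStore; a real implementation would batch or pipeline the fetches:

	// naiveGetMany is a hypothetical adapter around a cbor.IpldStore.
	type naiveGetMany struct {
		cbor.IpldStore
	}

	// GetMany decodes each cids[i] into outs[i], reporting the index of each
	// completed load (or the error it hit) on the returned channel.
	func (s naiveGetMany) GetMany(ctx context.Context, cids []cid.Cid, outs []interface{}) <-chan *OptionalInteger {
		ch := make(chan *OptionalInteger, len(cids))
		go func() {
			defer close(ch)
			for i, c := range cids {
				err := s.Get(ctx, c, outs[i])
				ch <- &OptionalInteger{Value: i, Error: err}
			}
		}()
		return ch
	}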

type listCidsAndShards struct {
cids []cid.Cid
shards []*Node
}

// walkChildren invokes f on each KV pair held directly in n's buckets and
// collects the links and cached shards of n's children for further traversal.
func (n *Node) walkChildren(f func(k string, val *cbg.Deferred) error) (*listCidsAndShards, error) {
res := &listCidsAndShards{}

for _, p := range n.Pointers {
if p.isShard() {
if p.cache != nil {
res.shards = append(res.shards, p.cache)
} else {
res.cids = append(res.cids, p.Link)
}
} else {
for _, kv := range p.KVs {
if err := f(string(kv.Key), kv.Value); err != nil {
return nil, err
}
}
}
}

return res, nil
}

// parallelShardWalk walks the HAMT concurrently, invoking processShardValues
// on each KV pair found in leaf buckets.
func parallelShardWalk(ctx context.Context, root *Node, processShardValues func(k string, val *cbg.Deferred) error) error {
const concurrency = 16 // TODO: should be an option, also this number was basically made up with a bit of empirical testing/usage

Review comment (Member):
It should be easy enough to make this an arg to ForEachParallel, yeah? Doing so would match the behavior of ParallelDiff.
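
A sketch of the reviewer's suggestion (hypothetical, not part of this diff):

	// Hypothetical signature; parallelShardWalk would take concurrency as an
	// argument instead of hard-coding the const above.
	func (n *Node) ForEachParallel(ctx context.Context, concurrency int, f func(k string, val *cbg.Deferred) error) error {
		return parallelShardWalk(ctx, n, concurrency, f)
	}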


var visitlk sync.Mutex
visitSet := cid.NewSet()
visit := visitSet.Visit

// Setup synchronization
grp, errGrpCtx := errgroup.WithContext(ctx)

// Input and output queues for workers; done signals that a worker has
// finished processing a batch taken from feed.
feed := make(chan *listCidsAndShards)
out := make(chan *listCidsAndShards)
done := make(chan struct{})

for i := 0; i < concurrency; i++ {
grp.Go(func() error {
for feedChildren := range feed {
for _, nextShard := range feedChildren.shards {
nextChildren, err := nextShard.walkChildren(processShardValues)
if err != nil {
return err
}

select {
case out <- nextChildren:
case <-errGrpCtx.Done():
return nil
}
}

var linksToVisit []cid.Cid
for _, nextCid := range feedChildren.cids {
var shouldVisit bool

visitlk.Lock()
shouldVisit = visit(nextCid)
visitlk.Unlock()

if shouldVisit {
linksToVisit = append(linksToVisit, nextCid)
}
}

// TODO: allow for Pointer caching
dserv, ok := root.store.(getManyIPLDStore)
if !ok {
return fmt.Errorf("hamt: store does not implement GetMany")
}
nodes := make([]interface{}, len(linksToVisit))
for i := 0; i < len(linksToVisit); i++ {
nodes[i] = new(Node)
}
chNodes := dserv.GetMany(errGrpCtx, linksToVisit, nodes)
for optNode := range chNodes {
if optNode.Error != nil {
return optNode.Error
}
nextShard := nodes[optNode.Value].(*Node)
nextShard.store = root.store
nextShard.bitWidth = root.bitWidth
nextShard.hash = root.hash
if err := validateNode(*nextShard, false); err != nil {
return err
}

nextChildren, err := nextShard.walkChildren(processShardValues)
if err != nil {
return err
}

select {
case out <- nextChildren:
case <-errGrpCtx.Done():
return nil
}
}

select {
case done <- struct{}{}:
case <-errGrpCtx.Done():
}
}
return nil
})
}

// Dispatcher state: send is nil while there is nothing ready to dispatch,
// todoQueue buffers work discovered faster than workers can accept it, and
// inProgress counts batches handed to workers but not yet marked done.
send := feed
var todoQueue []*listCidsAndShards
var inProgress int

next := &listCidsAndShards{
shards: []*Node{root},
}

dispatcherLoop:
for {
select {
case send <- next:
inProgress++
if len(todoQueue) > 0 {
next = todoQueue[0]
todoQueue = todoQueue[1:]
} else {
next = nil
send = nil
}
case <-done:
inProgress--
if inProgress == 0 && next == nil {
break dispatcherLoop
}
case nextNodes := <-out:
if next == nil {
next = nextNodes
send = feed
} else {
todoQueue = append(todoQueue, nextNodes)
}
case <-errGrpCtx.Done():
break dispatcherLoop
}
}
close(feed)
return grp.Wait()
}
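
Design note: all queueing lives on the dispatcher side. Workers only ever block on feed, out, done, or context cancellation; newly discovered work is either handed straight to a worker or buffered in the (unbounded) todoQueue, and the loop terminates once inProgress reaches zero with nothing left to dispatch.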