From 1a947c46b2d8a8bb86be0d2cae696f29c6a6f17f Mon Sep 17 00:00:00 2001
From: Dave Collins <davec@conformal.com>
Date: Fri, 18 Aug 2017 02:12:49 -0500
Subject: [PATCH] blockchain: Faster chain view block locator.

This exposes the ability to more efficiently create a block locator from
a chain view for a given block node by using their ability to do O(1)
lookups.

It also adds tests to ensure the behavior is correct.
---
 blockchain/chainview.go      | 82 ++++++++++++++++++++++++++++++++++++
 blockchain/chainview_test.go | 59 ++++++++++++++++++++++++++
 2 files changed, 141 insertions(+)

diff --git a/blockchain/chainview.go b/blockchain/chainview.go
index 3830cfbef..92daa7fd1 100644
--- a/blockchain/chainview.go
+++ b/blockchain/chainview.go
@@ -320,3 +320,85 @@ func (c *chainView) FindFork(node *blockNode) *blockNode {
 	c.mtx.Unlock()
 	return fork
 }
+
+// blockLocator returns a block locator for the passed block node.  The passed
+// node can be nil in which case the block locator for the current tip
+// associated with the view will be returned.  This only differs from the
+// exported version in that it is up to the caller to ensure the lock is held.
+//
+// See the exported BlockLocator function comments for more details.
+//
+// This function MUST be called with the view mutex locked (for reads).
+func (c *chainView) blockLocator(node *blockNode) BlockLocator {
+	// Use the current tip if requested.
+	if node == nil {
+		node = c.tip()
+	}
+	if node == nil {
+		return nil
+	}
+
+	// Calculate the max number of entries that will ultimately be in the
+	// block locator.  See the description of the algorithm for how these
+	// numbers are derived.
+	var maxEntries uint8
+	if node.height <= 12 {
+		maxEntries = uint8(node.height) + 1
+	} else {
+		// Requested hash itself + previous 10 entries + genesis block.
+		// Then floor(log2(height-10)) entries for the skip portion.
+		adjustedHeight := uint32(node.height) - 10
+		maxEntries = 12 + fastLog2Floor(adjustedHeight)
+	}
+	locator := make(BlockLocator, 0, maxEntries)
+
+	step := int32(1)
+	for node != nil {
+		locator = append(locator, &node.hash)
+
+		// Nothing more to add once the genesis block has been added.
+		if node.height == 0 {
+			break
+		}
+
+		// Calculate height of previous node to include ensuring the
+		// final node is the genesis block.
+		height := node.height - step
+		if height < 0 {
+			height = 0
+		}
+
+		// When the node is in the current chain view, all of its
+		// ancestors must be too, so use a much faster O(1) lookup in
+		// that case.  Otherwise, fall back to walking backwards through
+		// the nodes of the other chain to the correct ancestor.
+		if c.contains(node) {
+			node = c.nodes[height]
+		} else {
+			node = node.Ancestor(height)
+		}
+
+		// Once 11 entries have been included, start doubling the
+		// distance between included hashes.
+		if len(locator) > 10 {
+			step *= 2
+		}
+	}
+
+	return locator
+}
+
+// BlockLocator returns a block locator for the passed block node.  The passed
+// node can be nil in which case the block locator for the current tip
+// associated with the view will be returned.  This only differs from the
+// exported version in that it is up to the caller to ensure the lock is held.
+//
+// See BlockLocator for details on the algorithm used to create a block locator.
+//
+// This function is safe for concurrent access.
+func (c *chainView) BlockLocator(node *blockNode) BlockLocator {
+	c.mtx.Lock()
+	locator := c.blockLocator(node)
+	c.mtx.Unlock()
+	return locator
+}
diff --git a/blockchain/chainview_test.go b/blockchain/chainview_test.go
index 17b7f1f55..964da607b 100644
--- a/blockchain/chainview_test.go
+++ b/blockchain/chainview_test.go
@@ -7,6 +7,7 @@ package blockchain
 import (
 	"fmt"
 	"math/rand"
+	"reflect"
 	"testing"
 
 	"github.com/btcsuite/btcd/wire"
@@ -51,6 +52,27 @@ func tstTip(nodes []*blockNode) *blockNode {
 	return nodes[len(nodes)-1]
 }
 
+// locatorHashes is a convenience function that returns the hashes for all of
+// the passed indexes of the provided nodes.  It is used to construct expected
+// block locators in the tests.
+func locatorHashes(nodes []*blockNode, indexes ...int) BlockLocator {
+	hashes := make(BlockLocator, 0, len(indexes))
+	for _, idx := range indexes {
+		hashes = append(hashes, &nodes[idx].hash)
+	}
+	return hashes
+}
+
+// zipLocators is a convenience function that returns a single block locator
+// given a variable number of them and is used in the tests.
+func zipLocators(locators ...BlockLocator) BlockLocator {
+	var hashes BlockLocator
+	for _, locator := range locators {
+		hashes = append(hashes, locator...)
+	}
+	return hashes
+}
+
 // TestChainView ensures all of the exported functionality of chain views works
 // as intended with the expection of some special cases which are handled in
 // other tests.
@@ -77,6 +99,7 @@ func TestChainView(t *testing.T) {
 		noContains []*blockNode // expected nodes NOT in active view
 		equal      *chainView   // view expected equal to active view
 		unequal    *chainView   // view expected NOT equal to active
+		locator    BlockLocator // expected locator for active view tip
 	}{
 		{
 			// Create a view for branch 0 as the active chain and
@@ -92,6 +115,7 @@ func TestChainView(t *testing.T) {
 			noContains: branch1Nodes,
 			equal:      newChainView(tip(branch0Nodes)),
 			unequal:    newChainView(tip(branch1Nodes)),
+			locator:    locatorHashes(branch0Nodes, 4, 3, 2, 1, 0),
 		},
 		{
 			// Create a view for branch 1 as the active chain and
@@ -107,6 +131,10 @@ func TestChainView(t *testing.T) {
 			noContains: branch2Nodes,
 			equal:      newChainView(tip(branch1Nodes)),
 			unequal:    newChainView(tip(branch2Nodes)),
+			locator: zipLocators(
+				locatorHashes(branch1Nodes, 24, 23, 22, 21, 20,
+					19, 18, 17, 16, 15, 14, 13, 11, 7),
+				locatorHashes(branch0Nodes, 1, 0)),
 		},
 		{
 			// Create a view for branch 2 as the active chain and
@@ -122,6 +150,10 @@ func TestChainView(t *testing.T) {
 			noContains: branch0Nodes[2:],
 			equal:      newChainView(tip(branch2Nodes)),
 			unequal:    newChainView(tip(branch0Nodes)),
+			locator: zipLocators(
+				locatorHashes(branch2Nodes, 2, 1, 0),
+				locatorHashes(branch1Nodes, 0),
+				locatorHashes(branch0Nodes, 1, 0)),
 		},
 	}
 testLoop:
@@ -262,6 +294,15 @@ testLoop:
 				continue testLoop
 			}
 		}
+
+		// Ensure the block locator for the tip of the active view
+		// consists of the expected hashes.
+		locator := test.view.BlockLocator(test.view.tip())
+		if !reflect.DeepEqual(locator, test.locator) {
+			t.Errorf("%s: unexpected locator -- got %v, want %v",
+				test.name, locator, test.locator)
+			continue
+		}
 	}
 }
 
@@ -430,4 +471,22 @@ func TestChainViewNil(t *testing.T) {
 	if fork := view.FindFork(nil); fork != nil {
 		t.Fatalf("FindFork: unexpected fork -- got %v, want nil", fork)
 	}
+
+	// Ensure attempting to get a block locator for the tip doesn't produce
+	// one since the tip is nil.
+	if locator := view.BlockLocator(nil); locator != nil {
+		t.Fatalf("BlockLocator: unexpected locator -- got %v, want nil",
+			locator)
+	}
+
+	// Ensure attempting to get a block locator for a node that exists still
+	// works as intended.
+	branchNodes := chainedNodes(nil, 50)
+	wantLocator := locatorHashes(branchNodes, 49, 48, 47, 46, 45, 44, 43,
+		42, 41, 40, 39, 38, 36, 32, 24, 8, 0)
+	locator := view.BlockLocator(tstTip(branchNodes))
+	if !reflect.DeepEqual(locator, wantLocator) {
+		t.Fatalf("BlockLocator: unexpected locator -- got %v, want %v",
+			locator, wantLocator)
+	}
 }