Prevent a race condition in findHighestSharedBlockHash where we get headersSelectedTip and then pass it as highHash to GetBlockLocator, without locking consensus (#1410)

* Prevent a race condition in findHighestSharedBlockHash where we get headersSelectedTip and then pass it as highHash to GetBlockLocator, without locking consensus

* Restart findHighestSharedBlockHash if lowHash or highHash are no longer in selectedParentChain

* Test for specifically ErrBlockNotInSelectedParentChain instead of database NotFound error

* Fix TestCreateHeadersSelectedChainBlockLocator

Co-authored-by: Ori Newman <orinewman1@gmail.com>
This commit is contained in:
Svarog 2021-01-13 17:55:37 +02:00 committed by GitHub
parent 61be80a60c
commit 1b97cfb302
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 125 additions and 51 deletions

View File

@ -3,6 +3,8 @@ package blockrelay
import (
"time"
"github.com/kaspanet/kaspad/domain/consensus/model"
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/kaspanet/kaspad/app/protocol/common"
"github.com/kaspanet/kaspad/app/protocol/protocolerrors"
@ -120,59 +122,28 @@ func (flow *handleRelayInvsFlow) syncHeaders(highHash *externalapi.DomainHash) e
}
func (flow *handleRelayInvsFlow) findHighestSharedBlockHash(targetHash *externalapi.DomainHash) (*externalapi.DomainHash, error) {
lowHash, err := flow.Domain().Consensus().PruningPoint()
if err != nil {
return nil, err
}
highHash, err := flow.Domain().Consensus().GetHeadersSelectedTip()
log.Debugf("Sending a blockLocator to %s between pruning point and headers selected tip", flow.peer)
blockLocator, err := flow.Domain().Consensus().CreateFullHeadersSelectedChainBlockLocator()
if err != nil {
return nil, err
}
for !lowHash.Equal(highHash) {
log.Debugf("Sending a blockLocator to %s between %s and %s", flow.peer, lowHash, highHash)
blockLocator, err := flow.Domain().Consensus().CreateHeadersSelectedChainBlockLocator(lowHash, highHash)
for {
highestHash, err := flow.fetchHighestHash(targetHash, blockLocator)
if err != nil {
return nil, err
}
highestHashIndex, err := flow.findHighestHashIndex(highestHash, blockLocator)
if err != nil {
return nil, err
}
ibdBlockLocatorMessage := appmessage.NewMsgIBDBlockLocator(targetHash, blockLocator)
err = flow.outgoingRoute.Enqueue(ibdBlockLocatorMessage)
if err != nil {
return nil, err
}
message, err := flow.dequeueIncomingMessageAndSkipInvs(common.DefaultTimeout)
if err != nil {
return nil, err
}
ibdBlockLocatorHighestHashMessage, ok := message.(*appmessage.MsgIBDBlockLocatorHighestHash)
if !ok {
return nil, protocolerrors.Errorf(true, "received unexpected message type. "+
"expected: %s, got: %s", appmessage.CmdIBDBlockLocatorHighestHash, message.Command())
}
highestHash := ibdBlockLocatorHighestHashMessage.HighestHash
log.Debugf("The highest hash the peer %s knows is %s", flow.peer, highestHash)
if highestHashIndex == 0 ||
// If the block locator contains only two adjacent chain blocks, the
// syncer will always find the same highest chain block, so to avoid
// an endless loop, we explicitly stop the loop in such situation.
(len(blockLocator) == 2 && highestHashIndex == 1) {
highestHashIndex := 0
highestHashIndexFound := false
for i, blockLocatorHash := range blockLocator {
if highestHash.Equal(blockLocatorHash) {
highestHashIndex = i
highestHashIndexFound = true
break
}
}
if !highestHashIndexFound {
return nil, protocolerrors.Errorf(true, "highest hash %s "+
"returned from peer %s is not in the original blockLocator", highestHash, flow.peer)
}
log.Debugf("The index of the highest hash in the original "+
"blockLocator sent to %s is %d", flow.peer, highestHashIndex)
// If the block locator contains only two adjacent chain blocks, the
// syncer will always find the same highest chain block, so to avoid
// an endless loop, we explicitly stop the loop in such situation.
if len(blockLocator) == 2 && highestHashIndex == 1 {
return highestHash, nil
}
@ -180,10 +151,75 @@ func (flow *handleRelayInvsFlow) findHighestSharedBlockHash(targetHash *external
if highestHashIndex > 0 {
locatorHashAboveHighestHash = blockLocator[highestHashIndex-1]
}
highHash = locatorHashAboveHighestHash
lowHash = highestHash
blockLocator, err = flow.nextBlockLocator(highestHash, locatorHashAboveHighestHash)
if err != nil {
return nil, err
}
}
return highHash, nil
}
func (flow *handleRelayInvsFlow) nextBlockLocator(lowHash, highHash *externalapi.DomainHash) (externalapi.BlockLocator, error) {
log.Debugf("Sending a blockLocator to %s between %s and %s", flow.peer, lowHash, highHash)
blockLocator, err := flow.Domain().Consensus().CreateHeadersSelectedChainBlockLocator(lowHash, highHash)
if err != nil {
if errors.Is(model.ErrBlockNotInSelectedParentChain, err) {
return nil, err
}
log.Debugf("Headers selected parent chain moved since findHighestSharedBlockHash - " +
"restarting with full block locator")
blockLocator, err = flow.Domain().Consensus().CreateFullHeadersSelectedChainBlockLocator()
if err != nil {
return nil, err
}
}
return blockLocator, nil
}
func (flow *handleRelayInvsFlow) findHighestHashIndex(
highestHash *externalapi.DomainHash, blockLocator externalapi.BlockLocator) (int, error) {
highestHashIndex := 0
highestHashIndexFound := false
for i, blockLocatorHash := range blockLocator {
if highestHash.Equal(blockLocatorHash) {
highestHashIndex = i
highestHashIndexFound = true
break
}
}
if !highestHashIndexFound {
return 0, protocolerrors.Errorf(true, "highest hash %s "+
"returned from peer %s is not in the original blockLocator", highestHash, flow.peer)
}
log.Debugf("The index of the highest hash in the original "+
"blockLocator sent to %s is %d", flow.peer, highestHashIndex)
return highestHashIndex, nil
}
func (flow *handleRelayInvsFlow) fetchHighestHash(
targetHash *externalapi.DomainHash, blockLocator externalapi.BlockLocator) (*externalapi.DomainHash, error) {
ibdBlockLocatorMessage := appmessage.NewMsgIBDBlockLocator(targetHash, blockLocator)
err := flow.outgoingRoute.Enqueue(ibdBlockLocatorMessage)
if err != nil {
return nil, err
}
message, err := flow.dequeueIncomingMessageAndSkipInvs(common.DefaultTimeout)
if err != nil {
return nil, err
}
ibdBlockLocatorHighestHashMessage, ok := message.(*appmessage.MsgIBDBlockLocatorHighestHash)
if !ok {
return nil, protocolerrors.Errorf(true, "received unexpected message type. "+
"expected: %s, got: %s", appmessage.CmdIBDBlockLocatorHighestHash, message.Command())
}
highestHash := ibdBlockLocatorHighestHashMessage.HighestHash
log.Debugf("The highest hash the peer %s knows is %s", flow.peer, highestHash)
return highestHash, nil
}
func (flow *handleRelayInvsFlow) downloadHeaders(highestSharedBlockHash *externalapi.DomainHash,

View File

@ -1,9 +1,10 @@
package consensus
import (
"github.com/kaspanet/kaspad/infrastructure/db/database"
"sync"
"github.com/kaspanet/kaspad/infrastructure/db/database"
"github.com/kaspanet/kaspad/domain/consensus/model"
"github.com/kaspanet/kaspad/domain/consensus/model/externalapi"
"github.com/kaspanet/kaspad/domain/consensus/ruleerrors"
@ -298,6 +299,23 @@ func (s *consensus) CreateBlockLocator(lowHash, highHash *externalapi.DomainHash
return s.syncManager.CreateBlockLocator(lowHash, highHash, limit)
}
func (s *consensus) CreateFullHeadersSelectedChainBlockLocator() (externalapi.BlockLocator, error) {
s.lock.Lock()
defer s.lock.Unlock()
lowHash, err := s.pruningStore.PruningPoint(s.databaseContext)
if err != nil {
return nil, err
}
highHash, err := s.headersSelectedTipStore.HeadersSelectedTip(s.databaseContext)
if err != nil {
return nil, err
}
return s.syncManager.CreateHeadersSelectedChainBlockLocator(lowHash, highHash)
}
func (s *consensus) CreateHeadersSelectedChainBlockLocator(lowHash,
highHash *externalapi.DomainHash) (externalapi.BlockLocator, error) {
s.lock.Lock()

View File

@ -0,0 +1,7 @@
package model
import "github.com/pkg/errors"
// ErrBlockNotInSelectedParentChain is returned from CreateHeadersSelectedChainBlockLocator if one of the parameters
// passed to it are not in the headers selected parent chain
var ErrBlockNotInSelectedParentChain = errors.New("Block is not in selected parent chain")

View File

@ -19,6 +19,7 @@ type Consensus interface {
GetVirtualSelectedParent() (*DomainHash, error)
CreateBlockLocator(lowHash, highHash *DomainHash, limit uint32) (BlockLocator, error)
CreateHeadersSelectedChainBlockLocator(lowHash, highHash *DomainHash) (BlockLocator, error)
CreateFullHeadersSelectedChainBlockLocator() (BlockLocator, error)
GetSyncInfo() (*SyncInfo, error)
Tips() ([]*DomainHash, error)
GetVirtualInfo() (*VirtualInfo, error)

View File

@ -1,7 +1,9 @@
package syncmanager
import (
"github.com/kaspanet/kaspad/domain/consensus/model"
"github.com/kaspanet/kaspad/domain/consensus/model/externalapi"
"github.com/kaspanet/kaspad/infrastructure/db/database"
"github.com/pkg/errors"
)
@ -74,11 +76,19 @@ func (sm *syncManager) createHeadersSelectedChainBlockLocator(lowHash,
lowHashIndex, err := sm.headersSelectedChainStore.GetIndexByHash(sm.databaseContext, lowHash)
if err != nil {
if database.IsNotFoundError(err) {
return nil, errors.Wrapf(model.ErrBlockNotInSelectedParentChain,
"LowHash %s is not in selected parent chain", lowHash)
}
return nil, err
}
highHashIndex, err := sm.headersSelectedChainStore.GetIndexByHash(sm.databaseContext, highHash)
if err != nil {
if database.IsNotFoundError(err) {
return nil, errors.Wrapf(model.ErrBlockNotInSelectedParentChain,
"LowHash %s is not in selected parent chain", lowHash)
}
return nil, err
}

View File

@ -1,14 +1,16 @@
package syncmanager_test
import (
"strings"
"testing"
"github.com/kaspanet/kaspad/domain/consensus"
"github.com/kaspanet/kaspad/domain/consensus/model"
"github.com/kaspanet/kaspad/domain/consensus/model/externalapi"
"github.com/kaspanet/kaspad/domain/consensus/utils/testutils"
"github.com/kaspanet/kaspad/domain/dagconfig"
"github.com/kaspanet/kaspad/infrastructure/db/database"
"github.com/pkg/errors"
"strings"
"testing"
)
func TestCreateBlockLocator(t *testing.T) {
@ -224,7 +226,7 @@ func TestCreateHeadersSelectedChainBlockLocator(t *testing.T) {
// Check block locator with non chain blocks
_, err = tc.CreateHeadersSelectedChainBlockLocator(params.GenesisHash, sideChainTipHash)
if !errors.Is(err, database.ErrNotFound) {
if !errors.Is(err, model.ErrBlockNotInSelectedParentChain) {
t.Fatalf("expected error '%s' but got '%s'", database.ErrNotFound, err)
}
})