diff --git a/app/protocol/flows/handshake/handshake.go b/app/protocol/flows/handshake/handshake.go index 515ff5b56..8f305d859 100644 --- a/app/protocol/flows/handshake/handshake.go +++ b/app/protocol/flows/handshake/handshake.go @@ -1,7 +1,6 @@ package handshake import ( - "sync" "sync/atomic" "github.com/kaspanet/kaspad/domain" @@ -16,7 +15,6 @@ import ( "github.com/kaspanet/kaspad/app/appmessage" peerpkg "github.com/kaspanet/kaspad/app/protocol/peer" routerpkg "github.com/kaspanet/kaspad/infrastructure/network/netadapter/router" - "github.com/kaspanet/kaspad/util/locks" "github.com/pkg/errors" ) @@ -38,10 +36,12 @@ func HandleHandshake(context HandleHandshakeContext, netConnection *netadapter.N ) (*peerpkg.Peer, error) { // For HandleHandshake to finish, we need to get from the other node - // a version and verack messages, so we increase the wait group by 2 - // and block HandleHandshake with wg.Wait(). - wg := sync.WaitGroup{} - wg.Add(2) + // a version and verack messages, so we set doneCount to 2, decrease it + // when sending and receiving the version, and close the doneChan when + // it's 0. Then we wait for on select for a tick from doneChan or from + // errChan. + doneCount := int32(2) + doneChan := make(chan struct{}) isStopping := uint32(0) errChan := make(chan error) @@ -56,7 +56,9 @@ func HandleHandshake(context HandleHandshakeContext, netConnection *netadapter.N return } peerAddress = address - wg.Done() + if atomic.AddInt32(&doneCount, -1) == 0 { + close(doneChan) + } }) spawn("HandleHandshake-SendVersion", func() { @@ -65,7 +67,9 @@ func HandleHandshake(context HandleHandshakeContext, netConnection *netadapter.N handleError(err, "SendVersion", &isStopping, errChan) return } - wg.Done() + if atomic.AddInt32(&doneCount, -1) == 0 { + close(doneChan) + } }) select { @@ -74,7 +78,7 @@ func HandleHandshake(context HandleHandshakeContext, netConnection *netadapter.N return nil, err } return nil, nil - case <-locks.ReceiveFromChanWhenDone(func() { wg.Wait() }): + case <-doneChan: } err := context.AddToPeers(peer) diff --git a/testing/integration/64_incoming_connections_test.go b/testing/integration/64_incoming_connections_test.go index edc3055ca..3ac4a8699 100644 --- a/testing/integration/64_incoming_connections_test.go +++ b/testing/integration/64_incoming_connections_test.go @@ -6,8 +6,6 @@ import ( "testing" "time" - "github.com/kaspanet/kaspad/util/locks" - "github.com/kaspanet/kaspad/app/appmessage" ) @@ -56,6 +54,16 @@ func Test64IncomingConnections(t *testing.T) { select { case <-time.After(defaultTimeout): t.Fatalf("Timeout waiting for block added notification from the bullies") - case <-locks.ReceiveFromChanWhenDone(func() { blockAddedWG.Wait() }): + case <-ReceiveFromChanWhenDone(func() { blockAddedWG.Wait() }): } } + +// ReceiveFromChanWhenDone takes a blocking function and returns a channel that sends an empty struct when the function is done. +func ReceiveFromChanWhenDone(callback func()) <-chan struct{} { + ch := make(chan struct{}) + spawn("ReceiveFromChanWhenDone", func() { + callback() + close(ch) + }) + return ch +} diff --git a/testing/integration/ibd_test.go b/testing/integration/ibd_test.go index 1fb26ca67..78b39fc8a 100644 --- a/testing/integration/ibd_test.go +++ b/testing/integration/ibd_test.go @@ -5,8 +5,6 @@ import ( "testing" "time" - "github.com/kaspanet/kaspad/util/locks" - "github.com/kaspanet/kaspad/app/appmessage" ) @@ -33,7 +31,7 @@ func TestIBD(t *testing.T) { select { case <-time.After(defaultTimeout): t.Fatalf("Timeout waiting for IBD to finish. Received %d blocks out of %d", receivedBlocks, numBlocks) - case <-locks.ReceiveFromChanWhenDone(func() { blockAddedWG.Wait() }): + case <-ReceiveFromChanWhenDone(func() { blockAddedWG.Wait() }): } tip1Hash, err := syncer.rpcClient.GetSelectedTipHash() diff --git a/util/locks/receive_from_chan_when_done.go b/util/locks/receive_from_chan_when_done.go deleted file mode 100644 index 759b6c657..000000000 --- a/util/locks/receive_from_chan_when_done.go +++ /dev/null @@ -1,11 +0,0 @@ -package locks - -// ReceiveFromChanWhenDone takes a blocking function and returns a channel that sends an empty struct when the function is done. -func ReceiveFromChanWhenDone(callback func()) <-chan struct{} { - ch := make(chan struct{}) - spawn("ReceiveFromChanWhenDone", func() { - callback() - close(ch) - }) - return ch -}