[NOD-1001] Make an error in Peer.start() stop the connection process from continuing. (#723)

* [NOD-1001] Move side-effects of connection out of OnVersion

* [NOD-1001] Make AssociateConnection synchronous

* [NOD-1001] Wait for 2 veracks in TestPeerListeners

* [NOD-1001] Made AssociateConnection return error

* [NOD-1001] Remove temporary logs

* [NOD-1001] Fix typos and find-and-replace errors

* [NOD-1001] Move example_test back out of peer package + fix some typos

* [NOD-1001] Use correct remote address in setupPeersWithConns and return to address string literals

* [NOD-1001] Use separate verack channels for inPeer and outPeer

* [NOD-1001] Make verack channels buffered

* [NOD-1001] Removed temporary sleep of 1 second

* [NOD-1001] Removed redundant //
This commit is contained in:
Svarog 2020-05-20 10:36:44 +03:00 committed by GitHub
parent e0f587f599
commit fe25ea3d8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 225 additions and 146 deletions

View File

@ -9,11 +9,18 @@ import (
"net" "net"
"time" "time"
"github.com/kaspanet/kaspad/dagconfig" "github.com/kaspanet/kaspad/util/daghash"
"github.com/kaspanet/kaspad/peer" "github.com/kaspanet/kaspad/peer"
"github.com/kaspanet/kaspad/dagconfig"
"github.com/kaspanet/kaspad/wire" "github.com/kaspanet/kaspad/wire"
) )
func fakeSelectedTipFn() *daghash.Hash {
return &daghash.Hash{0x12, 0x34}
}
// mockRemotePeer creates a basic inbound peer listening on the simnet port for // mockRemotePeer creates a basic inbound peer listening on the simnet port for
// use with Example_peerConnection. It does not return until the listner is // use with Example_peerConnection. It does not return until the listner is
// active. // active.
@ -40,7 +47,11 @@ func mockRemotePeer() error {
// Create and start the inbound peer. // Create and start the inbound peer.
p := peer.NewInboundPeer(peerCfg) p := peer.NewInboundPeer(peerCfg)
p.AssociateConnection(conn) err = p.AssociateConnection(conn)
if err != nil {
fmt.Printf("AssociateConnection: error %+v\n", err)
return
}
}() }()
return nil return nil
@ -89,10 +100,14 @@ func Example_newOutboundPeer() {
// Establish the connection to the peer address and mark it connected. // Establish the connection to the peer address and mark it connected.
conn, err := net.Dial("tcp", p.Addr()) conn, err := net.Dial("tcp", p.Addr())
if err != nil { if err != nil {
fmt.Printf("net.Dial: error %v\n", err) fmt.Printf("net.Dial: error %+v\n", err)
return
}
err = p.AssociateConnection(conn)
if err != nil {
fmt.Printf("AssociateConnection: error %+v\n", err)
return return
} }
p.AssociateConnection(conn)
// Wait for the verack message or timeout in case of failure. // Wait for the verack message or timeout in case of failure.
select { select {

View File

@ -8,7 +8,6 @@ import (
"bytes" "bytes"
"container/list" "container/list"
"fmt" "fmt"
"github.com/pkg/errors"
"io" "io"
"math/rand" "math/rand"
"net" "net"
@ -17,6 +16,8 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/pkg/errors"
"github.com/kaspanet/kaspad/util/random" "github.com/kaspanet/kaspad/util/random"
"github.com/kaspanet/kaspad/util/subnetworkid" "github.com/kaspanet/kaspad/util/subnetworkid"
@ -1752,10 +1753,10 @@ func (p *Peer) QueueInventory(invVect *wire.InvVect) {
// AssociateConnection associates the given conn to the peer. Calling this // AssociateConnection associates the given conn to the peer. Calling this
// function when the peer is already connected will have no effect. // function when the peer is already connected will have no effect.
func (p *Peer) AssociateConnection(conn net.Conn) { func (p *Peer) AssociateConnection(conn net.Conn) error {
// Already connected? // Already connected?
if !atomic.CompareAndSwapInt32(&p.connected, 0, 1) { if !atomic.CompareAndSwapInt32(&p.connected, 0, 1) {
return return nil
} }
p.conn = conn p.conn = conn
@ -1769,19 +1770,18 @@ func (p *Peer) AssociateConnection(conn net.Conn) {
// and no point recomputing. // and no point recomputing.
na, err := newNetAddress(p.conn.RemoteAddr(), p.services) na, err := newNetAddress(p.conn.RemoteAddr(), p.services)
if err != nil { if err != nil {
log.Errorf("Cannot create remote net address: %s", err)
p.Disconnect() p.Disconnect()
return return errors.Wrap(err, "Cannot create remote net address")
} }
p.na = na p.na = na
} }
spawn(func() { if err := p.start(); err != nil {
if err := p.start(); err != nil { p.Disconnect()
log.Debugf("Cannot start peer %s: %s", p, err) return errors.Wrapf(err, "Cannot start peer %s", p)
p.Disconnect() }
}
}) return nil
} }
// Connected returns whether or not the peer is currently connected. // Connected returns whether or not the peer is currently connected.
@ -1841,6 +1841,7 @@ func (p *Peer) start() error {
// Send our verack message now that the IO processing machinery has started. // Send our verack message now that the IO processing machinery has started.
p.QueueMessage(wire.NewMsgVerAck(), nil) p.QueueMessage(wire.NewMsgVerAck(), nil)
return nil return nil
} }

View File

@ -2,20 +2,22 @@
// Use of this source code is governed by an ISC // Use of this source code is governed by an ISC
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package peer_test package peer
import ( import (
"io" "io"
"net" "net"
"strconv" "strconv"
"strings"
"testing" "testing"
"time" "time"
"github.com/kaspanet/kaspad/util/testtools"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/btcsuite/go-socks/socks" "github.com/btcsuite/go-socks/socks"
"github.com/kaspanet/kaspad/dagconfig" "github.com/kaspanet/kaspad/dagconfig"
"github.com/kaspanet/kaspad/peer"
"github.com/kaspanet/kaspad/util/daghash" "github.com/kaspanet/kaspad/util/daghash"
"github.com/kaspanet/kaspad/wire" "github.com/kaspanet/kaspad/wire"
) )
@ -110,7 +112,7 @@ type peerStats struct {
} }
// testPeer tests the given peer's flags and stats // testPeer tests the given peer's flags and stats
func testPeer(t *testing.T, p *peer.Peer, s peerStats) { func testPeer(t *testing.T, p *Peer, s peerStats) {
if p.UserAgent() != s.wantUserAgent { if p.UserAgent() != s.wantUserAgent {
t.Errorf("testPeer: wrong UserAgent - got %v, want %v", p.UserAgent(), s.wantUserAgent) t.Errorf("testPeer: wrong UserAgent - got %v, want %v", p.UserAgent(), s.wantUserAgent)
return return
@ -199,16 +201,18 @@ func testPeer(t *testing.T, p *peer.Peer, s peerStats) {
// TestPeerConnection tests connection between inbound and outbound peers. // TestPeerConnection tests connection between inbound and outbound peers.
func TestPeerConnection(t *testing.T) { func TestPeerConnection(t *testing.T) {
verack := make(chan struct{}) inPeerVerack, outPeerVerack, inPeerOnWriteVerack, outPeerOnWriteVerack :=
peer1Cfg := &peer.Config{ make(chan struct{}, 1), make(chan struct{}, 1), make(chan struct{}, 1), make(chan struct{}, 1)
Listeners: peer.MessageListeners{
OnVerAck: func(p *peer.Peer, msg *wire.MsgVerAck) { inPeerCfg := &Config{
verack <- struct{}{} Listeners: MessageListeners{
OnVerAck: func(p *Peer, msg *wire.MsgVerAck) {
inPeerVerack <- struct{}{}
}, },
OnWrite: func(p *peer.Peer, bytesWritten int, msg wire.Message, OnWrite: func(p *Peer, bytesWritten int, msg wire.Message,
err error) { err error) {
if _, ok := msg.(*wire.MsgVerAck); ok { if _, ok := msg.(*wire.MsgVerAck); ok {
verack <- struct{}{} inPeerOnWriteVerack <- struct{}{}
} }
}, },
}, },
@ -220,8 +224,18 @@ func TestPeerConnection(t *testing.T) {
Services: 0, Services: 0,
SelectedTipHash: fakeSelectedTipFn, SelectedTipHash: fakeSelectedTipFn,
} }
peer2Cfg := &peer.Config{ outPeerCfg := &Config{
Listeners: peer1Cfg.Listeners, Listeners: MessageListeners{
OnVerAck: func(p *Peer, msg *wire.MsgVerAck) {
outPeerVerack <- struct{}{}
},
OnWrite: func(p *Peer, bytesWritten int, msg wire.Message,
err error) {
if _, ok := msg.(*wire.MsgVerAck); ok {
outPeerOnWriteVerack <- struct{}{}
}
},
},
UserAgentName: "peer", UserAgentName: "peer",
UserAgentVersion: "1.0", UserAgentVersion: "1.0",
UserAgentComments: []string{"comment"}, UserAgentComments: []string{"comment"},
@ -262,56 +276,42 @@ func TestPeerConnection(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
setup func() (*peer.Peer, *peer.Peer, error) setup func() (*Peer, *Peer, error)
}{ }{
{ {
"basic handshake", "basic handshake",
func() (*peer.Peer, *peer.Peer, error) { func() (*Peer, *Peer, error) {
inConn, outConn := pipe( inPeer, outPeer, err := setupPeers(inPeerCfg, outPeerCfg)
&conn{raddr: "10.0.0.1:16111"},
&conn{raddr: "10.0.0.2:16111"},
)
inPeer := peer.NewInboundPeer(peer1Cfg)
inPeer.AssociateConnection(inConn)
outPeer, err := peer.NewOutboundPeer(peer2Cfg, "10.0.0.2:16111")
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
outPeer.AssociateConnection(outConn)
for i := 0; i < 4; i++ { // wait for 4 veracks
select { if !testtools.WaitTillAllCompleteOrTimeout(time.Second,
case <-verack: inPeerVerack, inPeerOnWriteVerack, outPeerVerack, outPeerOnWriteVerack) {
case <-time.After(time.Second):
return nil, nil, errors.New("verack timeout") return nil, nil, errors.New("handshake timeout")
}
} }
return inPeer, outPeer, nil return inPeer, outPeer, nil
}, },
}, },
{ {
"socks proxy", "socks proxy",
func() (*peer.Peer, *peer.Peer, error) { func() (*Peer, *Peer, error) {
inConn, outConn := pipe( inConn, outConn := pipe(
&conn{raddr: "10.0.0.1:16111", proxy: true}, &conn{raddr: "10.0.0.1:16111", proxy: true},
&conn{raddr: "10.0.0.2:16111"}, &conn{raddr: "10.0.0.2:16111"},
) )
inPeer := peer.NewInboundPeer(peer1Cfg) inPeer, outPeer, err := setupPeersWithConns(inPeerCfg, outPeerCfg, inConn, outConn)
inPeer.AssociateConnection(inConn)
outPeer, err := peer.NewOutboundPeer(peer2Cfg, "10.0.0.2:16111")
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
outPeer.AssociateConnection(outConn)
for i := 0; i < 4; i++ { // wait for 4 veracks
select { if !testtools.WaitTillAllCompleteOrTimeout(time.Second,
case <-verack: inPeerVerack, inPeerOnWriteVerack, outPeerVerack, outPeerOnWriteVerack) {
case <-time.After(time.Second):
return nil, nil, errors.New("verack timeout") return nil, nil, errors.New("handshake timeout")
}
} }
return inPeer, outPeer, nil return inPeer, outPeer, nil
}, },
@ -336,62 +336,62 @@ func TestPeerConnection(t *testing.T) {
// TestPeerListeners tests that the peer listeners are called as expected. // TestPeerListeners tests that the peer listeners are called as expected.
func TestPeerListeners(t *testing.T) { func TestPeerListeners(t *testing.T) {
verack := make(chan struct{}, 1) inPeerVerack, outPeerVerack := make(chan struct{}, 1), make(chan struct{}, 1)
ok := make(chan wire.Message, 20) ok := make(chan wire.Message, 20)
peerCfg := &peer.Config{ inPeerCfg := &Config{
Listeners: peer.MessageListeners{ Listeners: MessageListeners{
OnGetAddr: func(p *peer.Peer, msg *wire.MsgGetAddr) { OnGetAddr: func(p *Peer, msg *wire.MsgGetAddr) {
ok <- msg ok <- msg
}, },
OnAddr: func(p *peer.Peer, msg *wire.MsgAddr) { OnAddr: func(p *Peer, msg *wire.MsgAddr) {
ok <- msg ok <- msg
}, },
OnPing: func(p *peer.Peer, msg *wire.MsgPing) { OnPing: func(p *Peer, msg *wire.MsgPing) {
ok <- msg ok <- msg
}, },
OnPong: func(p *peer.Peer, msg *wire.MsgPong) { OnPong: func(p *Peer, msg *wire.MsgPong) {
ok <- msg ok <- msg
}, },
OnTx: func(p *peer.Peer, msg *wire.MsgTx) { OnTx: func(p *Peer, msg *wire.MsgTx) {
ok <- msg ok <- msg
}, },
OnBlock: func(p *peer.Peer, msg *wire.MsgBlock, buf []byte) { OnBlock: func(p *Peer, msg *wire.MsgBlock, buf []byte) {
ok <- msg ok <- msg
}, },
OnInv: func(p *peer.Peer, msg *wire.MsgInv) { OnInv: func(p *Peer, msg *wire.MsgInv) {
ok <- msg ok <- msg
}, },
OnNotFound: func(p *peer.Peer, msg *wire.MsgNotFound) { OnNotFound: func(p *Peer, msg *wire.MsgNotFound) {
ok <- msg ok <- msg
}, },
OnGetData: func(p *peer.Peer, msg *wire.MsgGetData) { OnGetData: func(p *Peer, msg *wire.MsgGetData) {
ok <- msg ok <- msg
}, },
OnGetBlockInvs: func(p *peer.Peer, msg *wire.MsgGetBlockInvs) { OnGetBlockInvs: func(p *Peer, msg *wire.MsgGetBlockInvs) {
ok <- msg ok <- msg
}, },
OnFeeFilter: func(p *peer.Peer, msg *wire.MsgFeeFilter) { OnFeeFilter: func(p *Peer, msg *wire.MsgFeeFilter) {
ok <- msg ok <- msg
}, },
OnFilterAdd: func(p *peer.Peer, msg *wire.MsgFilterAdd) { OnFilterAdd: func(p *Peer, msg *wire.MsgFilterAdd) {
ok <- msg ok <- msg
}, },
OnFilterClear: func(p *peer.Peer, msg *wire.MsgFilterClear) { OnFilterClear: func(p *Peer, msg *wire.MsgFilterClear) {
ok <- msg ok <- msg
}, },
OnFilterLoad: func(p *peer.Peer, msg *wire.MsgFilterLoad) { OnFilterLoad: func(p *Peer, msg *wire.MsgFilterLoad) {
ok <- msg ok <- msg
}, },
OnMerkleBlock: func(p *peer.Peer, msg *wire.MsgMerkleBlock) { OnMerkleBlock: func(p *Peer, msg *wire.MsgMerkleBlock) {
ok <- msg ok <- msg
}, },
OnVersion: func(p *peer.Peer, msg *wire.MsgVersion) { OnVersion: func(p *Peer, msg *wire.MsgVersion) {
ok <- msg ok <- msg
}, },
OnVerAck: func(p *peer.Peer, msg *wire.MsgVerAck) { OnVerAck: func(p *Peer, msg *wire.MsgVerAck) {
verack <- struct{}{} inPeerVerack <- struct{}{}
}, },
OnReject: func(p *peer.Peer, msg *wire.MsgReject) { OnReject: func(p *Peer, msg *wire.MsgReject) {
ok <- msg ok <- msg
}, },
}, },
@ -402,32 +402,20 @@ func TestPeerListeners(t *testing.T) {
Services: wire.SFNodeBloom, Services: wire.SFNodeBloom,
SelectedTipHash: fakeSelectedTipFn, SelectedTipHash: fakeSelectedTipFn,
} }
inConn, outConn := pipe(
&conn{raddr: "10.0.0.1:16111"},
&conn{raddr: "10.0.0.2:16111"},
)
inPeer := peer.NewInboundPeer(peerCfg)
inPeer.AssociateConnection(inConn)
peerCfg.Listeners = peer.MessageListeners{ outPeerCfg := &Config{}
OnVerAck: func(p *peer.Peer, msg *wire.MsgVerAck) { *outPeerCfg = *inPeerCfg // copy inPeerCfg
verack <- struct{}{} outPeerCfg.Listeners.OnVerAck = func(p *Peer, msg *wire.MsgVerAck) {
}, outPeerVerack <- struct{}{}
} }
outPeer, err := peer.NewOutboundPeer(peerCfg, "10.0.0.1:16111")
inPeer, outPeer, err := setupPeers(inPeerCfg, outPeerCfg)
if err != nil { if err != nil {
t.Errorf("NewOutboundPeer: unexpected err %v\n", err) t.Errorf("TestPeerListeners: %v", err)
return
} }
outPeer.AssociateConnection(outConn) // wait for 2 veracks
if !testtools.WaitTillAllCompleteOrTimeout(time.Second, inPeerVerack, outPeerVerack) {
for i := 0; i < 2; i++ { t.Errorf("TestPeerListeners: Timout waiting for veracks")
select {
case <-verack:
case <-time.After(time.Second * 1):
t.Errorf("TestPeerListeners: verack timeout\n")
return
}
} }
tests := []struct { tests := []struct {
@ -520,7 +508,7 @@ func TestPeerListeners(t *testing.T) {
// TestOutboundPeer tests that the outbound peer works as expected. // TestOutboundPeer tests that the outbound peer works as expected.
func TestOutboundPeer(t *testing.T) { func TestOutboundPeer(t *testing.T) {
peerCfg := &peer.Config{ peerCfg := &Config{
SelectedTipHash: func() *daghash.Hash { SelectedTipHash: func() *daghash.Hash {
return &daghash.ZeroHash return &daghash.ZeroHash
}, },
@ -531,18 +519,16 @@ func TestOutboundPeer(t *testing.T) {
Services: 0, Services: 0,
} }
r, w := io.Pipe() _, p, err := setupPeers(peerCfg, peerCfg)
c := &conn{raddr: "10.0.0.1:16111", Writer: w, Reader: r}
p, err := peer.NewOutboundPeer(peerCfg, "10.0.0.1:16111")
if err != nil { if err != nil {
t.Errorf("NewOutboundPeer: unexpected err - %v\n", err) t.Fatalf("TestOuboundPeer: unexpected err in setupPeers - %v\n", err)
return
} }
// Test trying to connect twice. // Test trying to connect for a second time and make sure nothing happens.
p.AssociateConnection(c) err = p.AssociateConnection(p.conn)
p.AssociateConnection(c) if err != nil {
t.Fatalf("AssociateConnection for the second time didn't return nil")
}
p.Disconnect() p.Disconnect()
// Test Queue Inv // Test Queue Inv
@ -572,14 +558,11 @@ func TestOutboundPeer(t *testing.T) {
} }
peerCfg.SelectedTipHash = selectedTipHash peerCfg.SelectedTipHash = selectedTipHash
r1, w1 := io.Pipe()
c1 := &conn{raddr: "10.0.0.1:16111", Writer: w1, Reader: r1} _, p1, err := setupPeers(peerCfg, peerCfg)
p1, err := peer.NewOutboundPeer(peerCfg, "10.0.0.1:16111")
if err != nil { if err != nil {
t.Errorf("NewOutboundPeer: unexpected err - %v\n", err) t.Fatalf("TestOuboundPeer: unexpected err in setupPeers - %v\n", err)
return
} }
p1.AssociateConnection(c1)
// Test Queue Inv after connection // Test Queue Inv after connection
p1.QueueInventory(fakeInv) p1.QueueInventory(fakeInv)
@ -588,14 +571,10 @@ func TestOutboundPeer(t *testing.T) {
// Test regression // Test regression
peerCfg.DAGParams = &dagconfig.RegressionNetParams peerCfg.DAGParams = &dagconfig.RegressionNetParams
peerCfg.Services = wire.SFNodeBloom peerCfg.Services = wire.SFNodeBloom
r2, w2 := io.Pipe() _, p2, err := setupPeers(peerCfg, peerCfg)
c2 := &conn{raddr: "10.0.0.1:16111", Writer: w2, Reader: r2}
p2, err := peer.NewOutboundPeer(peerCfg, "10.0.0.1:16111")
if err != nil { if err != nil {
t.Errorf("NewOutboundPeer: unexpected err - %v\n", err) t.Fatalf("NewOutboundPeer: unexpected err - %v\n", err)
return
} }
p2.AssociateConnection(c2)
// Test PushXXX // Test PushXXX
var addrs []*wire.NetAddress var addrs []*wire.NetAddress
@ -604,12 +583,10 @@ func TestOutboundPeer(t *testing.T) {
addrs = append(addrs, &na) addrs = append(addrs, &na)
} }
if _, err := p2.PushAddrMsg(addrs, nil); err != nil { if _, err := p2.PushAddrMsg(addrs, nil); err != nil {
t.Errorf("PushAddrMsg: unexpected err %v\n", err) t.Fatalf("PushAddrMsg: unexpected err %v\n", err)
return
} }
if err := p2.PushGetBlockInvsMsg(nil, &daghash.Hash{}); err != nil { if err := p2.PushGetBlockInvsMsg(nil, &daghash.Hash{}); err != nil {
t.Errorf("PushGetBlockInvsMsg: unexpected err %v\n", err) t.Fatalf("PushGetBlockInvsMsg: unexpected err %v\n", err)
return
} }
p2.PushRejectMsg("block", wire.RejectMalformed, "malformed", nil, false) p2.PushRejectMsg("block", wire.RejectMalformed, "malformed", nil, false)
@ -627,7 +604,7 @@ func TestOutboundPeer(t *testing.T) {
// Tests that the node disconnects from peers with an unsupported protocol // Tests that the node disconnects from peers with an unsupported protocol
// version. // version.
func TestUnsupportedVersionPeer(t *testing.T) { func TestUnsupportedVersionPeer(t *testing.T) {
peerCfg := &peer.Config{ peerCfg := &Config{
UserAgentName: "peer", UserAgentName: "peer",
UserAgentVersion: "1.0", UserAgentVersion: "1.0",
UserAgentComments: []string{"comment"}, UserAgentComments: []string{"comment"},
@ -637,12 +614,12 @@ func TestUnsupportedVersionPeer(t *testing.T) {
} }
localNA := wire.NewNetAddressIPPort( localNA := wire.NewNetAddressIPPort(
net.ParseIP("10.0.0.1"), net.ParseIP("10.0.0.1:16111"),
uint16(16111), uint16(16111),
wire.SFNodeNetwork, wire.SFNodeNetwork,
) )
remoteNA := wire.NewNetAddressIPPort( remoteNA := wire.NewNetAddressIPPort(
net.ParseIP("10.0.0.2"), net.ParseIP("10.0.0.2:16111"),
uint16(16111), uint16(16111),
wire.SFNodeNetwork, wire.SFNodeNetwork,
) )
@ -651,11 +628,23 @@ func TestUnsupportedVersionPeer(t *testing.T) {
&conn{laddr: "10.0.0.2:16111", raddr: "10.0.0.1:16111"}, &conn{laddr: "10.0.0.2:16111", raddr: "10.0.0.1:16111"},
) )
p, err := peer.NewOutboundPeer(peerCfg, "10.0.0.1:16111") p, err := NewOutboundPeer(peerCfg, "10.0.0.1:16111")
if err != nil { if err != nil {
t.Fatalf("NewOutboundPeer: unexpected err - %v\n", err) t.Fatalf("NewOutboundPeer: unexpected err - %v\n", err)
} }
p.AssociateConnection(localConn)
go func() {
err := p.AssociateConnection(localConn)
wantErrorMessage := "protocol version must be 1 or greater"
if err == nil {
t.Fatalf("No error from AssociateConnection to invalid protocol version")
}
gotErrorMessage := err.Error()
if !strings.Contains(gotErrorMessage, wantErrorMessage) {
t.Fatalf("Wrong error message from AssociateConnection to invalid protocol version.\nWant: '%s'\nGot: '%s'",
wantErrorMessage, gotErrorMessage)
}
}()
// Read outbound messages to peer into a channel // Read outbound messages to peer into a channel
outboundMessages := make(chan wire.Message) outboundMessages := make(chan wire.Message)
@ -730,9 +719,56 @@ func TestUnsupportedVersionPeer(t *testing.T) {
func init() { func init() {
// Allow self connection when running the tests. // Allow self connection when running the tests.
peer.TstAllowSelfConns() TstAllowSelfConns()
} }
func fakeSelectedTipFn() *daghash.Hash { func fakeSelectedTipFn() *daghash.Hash {
return &daghash.Hash{0x12, 0x34} return &daghash.Hash{0x12, 0x34}
} }
func setupPeers(inPeerCfg, outPeerCfg *Config) (inPeer *Peer, outPeer *Peer, err error) {
inConn, outConn := pipe(
&conn{raddr: "10.0.0.1:16111"},
&conn{raddr: "10.0.0.2:16111"},
)
return setupPeersWithConns(inPeerCfg, outPeerCfg, inConn, outConn)
}
func setupPeersWithConns(inPeerCfg, outPeerCfg *Config, inConn, outConn *conn) (inPeer *Peer, outPeer *Peer, err error) {
inPeer = NewInboundPeer(inPeerCfg)
inPeerDone := make(chan struct{})
var inPeerErr error
go func() {
inPeerErr = inPeer.AssociateConnection(inConn)
inPeerDone <- struct{}{}
}()
outPeer, err = NewOutboundPeer(outPeerCfg, outConn.raddr)
if err != nil {
return nil, nil, err
}
outPeerDone := make(chan struct{})
var outPeerErr error
go func() {
outPeerErr = outPeer.AssociateConnection(outConn)
outPeerDone <- struct{}{}
}()
// wait for AssociateConnection to complete in all instances
if !testtools.WaitTillAllCompleteOrTimeout(2*time.Second, inPeerDone, outPeerDone) {
return nil, nil, errors.New("handshake timeout")
}
if inPeerErr != nil && outPeerErr != nil {
return nil, nil, errors.Errorf("both inPeer and outPeer failed connecting: \nInPeer: %+v\nOutPeer: %+v",
inPeerErr, outPeerErr)
}
if inPeerErr != nil {
return nil, nil, inPeerErr
}
if outPeerErr != nil {
return nil, nil, outPeerErr
}
return inPeer, outPeer, nil
}

View File

@ -11,9 +11,6 @@ import (
// and is used to negotiate the protocol version details as well as kick start // and is used to negotiate the protocol version details as well as kick start
// the communications. // the communications.
func (sp *Peer) OnVersion(_ *peer.Peer, msg *wire.MsgVersion) { func (sp *Peer) OnVersion(_ *peer.Peer, msg *wire.MsgVersion) {
// Signal the sync manager this peer is a new sync candidate.
sp.server.SyncManager.NewPeer(sp.Peer)
// Choose whether or not to relay transactions before a filter command // Choose whether or not to relay transactions before a filter command
// is received. // is received.
sp.setDisableRelayTx(msg.DisableRelayTx) sp.setDisableRelayTx(msg.DisableRelayTx)
@ -54,7 +51,4 @@ func (sp *Peer) OnVersion(_ *peer.Peer, msg *wire.MsgVersion) {
addrManager.Good(sp.NA(), msg.SubnetworkID) addrManager.Good(sp.NA(), msg.SubnetworkID)
} }
} }
// Add valid peer to the server.
sp.server.AddPeer(sp)
} }

View File

@ -967,12 +967,9 @@ func newPeerConfig(sp *Peer) *peer.Config {
// for disconnection. // for disconnection.
func (s *Server) inboundPeerConnected(conn net.Conn) { func (s *Server) inboundPeerConnected(conn net.Conn) {
sp := newServerPeer(s, false) sp := newServerPeer(s, false)
sp.isWhitelisted = isWhitelisted(conn.RemoteAddr())
sp.Peer = peer.NewInboundPeer(newPeerConfig(sp)) sp.Peer = peer.NewInboundPeer(newPeerConfig(sp))
sp.AssociateConnection(conn)
spawn(func() { s.peerConnected(sp, conn)
s.peerDoneHandler(sp)
})
} }
// outboundPeerConnected is invoked by the connection manager when a new // outboundPeerConnected is invoked by the connection manager when a new
@ -989,12 +986,28 @@ func (s *Server) outboundPeerConnected(state *peerState, msg *outboundPeerConnec
} }
sp.Peer = outboundPeer sp.Peer = outboundPeer
sp.connReq = msg.connReq sp.connReq = msg.connReq
sp.isWhitelisted = isWhitelisted(msg.conn.RemoteAddr())
sp.AssociateConnection(msg.conn) s.peerConnected(sp, msg.conn)
s.addrManager.Attempt(sp.NA())
}
func (s *Server) peerConnected(sp *Peer, conn net.Conn) {
sp.isWhitelisted = isWhitelisted(conn.RemoteAddr())
spawn(func() { spawn(func() {
err := sp.AssociateConnection(conn)
if err != nil {
peerLog.Debugf("Error connecting to peer: %+v", err)
return
}
s.SyncManager.NewPeer(sp.Peer)
s.AddPeer(sp)
s.peerDoneHandler(sp) s.peerDoneHandler(sp)
}) })
s.addrManager.Attempt(sp.NA())
} }
// outboundPeerConnected is invoked by the connection manager when a new // outboundPeerConnected is invoked by the connection manager when a new

View File

@ -1,6 +1,8 @@
package testtools package testtools
import ( import (
"time"
"github.com/kaspanet/kaspad/dagconfig" "github.com/kaspanet/kaspad/dagconfig"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -94,3 +96,21 @@ func RegisterSubnetworkForTest(dag *blockdag.BlockDAG, params *dagconfig.Params,
} }
return subnetworkID, nil return subnetworkID, nil
} }
// WaitTillAllCompleteOrTimeout waits until all the provided channels has been written to,
// or until a timeout period has passed.
// Returns true iff all channels returned in the allotted time.
func WaitTillAllCompleteOrTimeout(timeoutDuration time.Duration, chans ...chan struct{}) (ok bool) {
timeout := time.After(timeoutDuration)
for _, c := range chans {
select {
case <-c:
continue
case <-timeout:
return false
}
}
return true
}