diff --git a/addrmgr/addrmanager.go b/addrmgr/addrmanager.go index 0dcff7b37..4d553ee73 100644 --- a/addrmgr/addrmanager.go +++ b/addrmgr/addrmanager.go @@ -708,27 +708,32 @@ func (a *AddrManager) NeedMoreAddresses() bool { a.mtx.Lock() defer a.mtx.Unlock() - return a.numAddresses(a.localSubnetworkID)+a.numAddresses(&wire.SubnetworkIDUnknown) < needAddressThreshold + allAddrs := a.numAddresses(a.localSubnetworkID) + a.numAddresses(&wire.SubnetworkIDUnknown) + if !a.localSubnetworkID.IsEqual(&wire.SubnetworkIDSupportsAll) { + allAddrs += a.numAddresses(&wire.SubnetworkIDSupportsAll) + } + return allAddrs < needAddressThreshold } // AddressCache returns the current address cache. It must be treated as // read-only (but since it is a copy now, this is not as dangerous). -func (a *AddrManager) AddressCache() []*wire.NetAddress { +func (a *AddrManager) AddressCache(subnetworkID *subnetworkid.SubnetworkID) []*wire.NetAddress { a.mtx.Lock() defer a.mtx.Unlock() - addrIndexLen := len(a.addrIndex) - if addrIndexLen == 0 { + if len(a.addrIndex) == 0 { return nil } - allAddr := make([]*wire.NetAddress, 0, addrIndexLen) + allAddr := []*wire.NetAddress{} // Iteration order is undefined here, but we randomise it anyway. for _, v := range a.addrIndex { - allAddr = append(allAddr, v.na) + if subnetworkID == nil || v.SubnetworkID().IsEqual(subnetworkID) { + allAddr = append(allAddr, v.na) + } } - numAddresses := addrIndexLen * getAddrPercent / 100 + numAddresses := len(allAddr) * getAddrPercent / 100 if numAddresses > getAddrMax { numAddresses = getAddrMax } @@ -737,7 +742,7 @@ func (a *AddrManager) AddressCache() []*wire.NetAddress { // `numAddresses' since we are throwing the rest. for i := 0; i < numAddresses; i++ { // pick a number between current index and the end - j := rand.Intn(addrIndexLen-i) + i + j := rand.Intn(len(allAddr)-i) + i allAddr[i], allAddr[j] = allAddr[j], allAddr[i] } diff --git a/addrmgr/addrmanager_test.go b/addrmgr/addrmanager_test.go index 42242ef01..6c417fdcf 100644 --- a/addrmgr/addrmanager_test.go +++ b/addrmgr/addrmanager_test.go @@ -288,6 +288,8 @@ func TestGood(t *testing.T) { n := addrmgr.New("testgood", lookupFunc, &wire.SubnetworkIDSupportsAll) addrsToAdd := 64 * 64 addrs := make([]*wire.NetAddress, addrsToAdd) + subnetworkCount := 32 + subnetworkIDs := make([]*subnetworkid.SubnetworkID, subnetworkCount) var err error for i := 0; i < addrsToAdd; i++ { @@ -298,11 +300,15 @@ func TestGood(t *testing.T) { } } + for i := 0; i < subnetworkCount; i++ { + subnetworkIDs[i] = &subnetworkid.SubnetworkID{0xff - byte(i)} + } + srcAddr := wire.NewNetAddressIPPort(net.IPv4(173, 144, 173, 111), 8333, 0) n.AddAddresses(addrs, srcAddr) - for _, addr := range addrs { - n.Good(addr, &wire.SubnetworkIDSupportsAll) + for i, addr := range addrs { + n.Good(addr, subnetworkIDs[i%subnetworkCount]) } numAddrs := n.TotalNumAddresses() @@ -310,9 +316,18 @@ func TestGood(t *testing.T) { t.Errorf("Number of addresses is too many: %d vs %d", numAddrs, addrsToAdd) } - numCache := len(n.AddressCache()) - if numCache >= numAddrs/4 { - t.Errorf("Number of addresses in cache: got %d, want %d", numCache, numAddrs/4) + numCache := len(n.AddressCache(nil)) + if numCache == 0 || numCache >= numAddrs/4 { + t.Errorf("Number of addresses in cache: got %d, want positive and less than %d", + numCache, numAddrs/4) + } + + for i := 0; i < subnetworkCount; i++ { + numCache = len(n.AddressCache(subnetworkIDs[i])) + if numCache == 0 || numCache >= numAddrs/subnetworkCount { + t.Errorf("Number of addresses in subnetwork cache: got %d, want positive and less than %d", + numCache, numAddrs/4/subnetworkCount) + } } } diff --git a/config/config.go b/config/config.go index d389d0660..7d63d2f8c 100644 --- a/config/config.go +++ b/config/config.go @@ -166,7 +166,7 @@ type configFlags struct { DropAddrIndex bool `long:"dropaddrindex" description:"Deletes the address-based transaction index from the database on start up and then exits."` RelayNonStd bool `long:"relaynonstd" description:"Relay non-standard transactions regardless of the default settings for the active network."` RejectNonStd bool `long:"rejectnonstd" description:"Reject non-standard transactions regardless of the default settings for the active network."` - Subnetwork string `string:"subnetwork" description:"If subnetwork != 0, than node will request and process only payloads from specified subnetwork. And if subnetwork is 0, than payloads of all subnetworks are processed. Subnetworks 3 through 255 are reserved for future use and are currently not allowed."` + Subnetwork string `long:"subnetwork" description:"If subnetwork ID != 0, than node will request and process only payloads from specified subnetwork. And if subnetwork ID is 0, than payloads of all subnetworks are processed. Subnetworks with IDs 3 through 255 are reserved for future use and are currently not allowed."` } // Config defines the configuration options for btcd. @@ -758,6 +758,8 @@ func loadConfig() (*Config, []string, error) { if err != nil { return nil, nil, err } + } else { + cfg.SubnetworkID = &wire.SubnetworkIDSupportsAll } // Check that 'generate' and 'subnetwork' flags do not conflict diff --git a/connmgr/seed.go b/connmgr/seed.go index 7569679b7..a28bca3e4 100644 --- a/connmgr/seed.go +++ b/connmgr/seed.go @@ -11,6 +11,8 @@ import ( "strconv" "time" + "github.com/daglabs/btcd/util/subnetworkid" + "github.com/daglabs/btcd/dagconfig" "github.com/daglabs/btcd/wire" ) @@ -30,7 +32,7 @@ type OnSeed func(addrs []*wire.NetAddress) type LookupFunc func(string) ([]net.IP, error) // SeedFromDNS uses DNS seeding to populate the address manager with peers. -func SeedFromDNS(dagParams *dagconfig.Params, reqServices wire.ServiceFlag, +func SeedFromDNS(dagParams *dagconfig.Params, reqServices wire.ServiceFlag, subnetworkID *subnetworkid.SubnetworkID, lookupFn LookupFunc, seedFn OnSeed) { for _, dnsseed := range dagParams.DNSSeeds { @@ -41,6 +43,10 @@ func SeedFromDNS(dagParams *dagconfig.Params, reqServices wire.ServiceFlag, host = fmt.Sprintf("x%x.%s", uint64(reqServices), dnsseed.Host) } + if !subnetworkID.IsEqual(&wire.SubnetworkIDSupportsAll) { + host = fmt.Sprintf("n%s.%s", subnetworkID, host) + } + go func(host string) { randSource := mrand.New(mrand.NewSource(time.Now().UnixNano())) diff --git a/dagconfig/params.go b/dagconfig/params.go index ec02a2d74..daf7edd0d 100644 --- a/dagconfig/params.go +++ b/dagconfig/params.go @@ -340,7 +340,12 @@ var TestNet3Params = Params{ Net: wire.TestNet3, RPCPort: "18334", DefaultPort: "18333", - DNSSeeds: []DNSSeed{}, + DNSSeeds: []DNSSeed{ + {"testnet-seed.alexykot.me", false}, + {"testnet-seed.bitcoin.petertodd.org", false}, + {"testnet-seed.bluematt.me", false}, + {"testnet-seed.bitcoin.schildbach.de", false}, + }, // Chain parameters GenesisBlock: &testNet3GenesisBlock, diff --git a/peer/log.go b/peer/log.go index 0b9190427..16be368bc 100644 --- a/peer/log.go +++ b/peer/log.go @@ -131,7 +131,7 @@ func messageSummary(msg wire.Message) string { // No summary. case *wire.MsgGetAddr: - // No summary. + return fmt.Sprintf("subnetwork ID %v", msg.SubnetworkID) case *wire.MsgAddr: return fmt.Sprintf("%d addr", len(msg.AddrList)) diff --git a/peer/peer.go b/peer/peer.go index 952d23ea3..2ea337efc 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -874,7 +874,7 @@ func (p *Peer) localVersionMsg() (*wire.MsgVersion, error) { // message will be sent if there are no entries in the provided addresses slice. // // This function is safe for concurrent access. -func (p *Peer) PushAddrMsg(addresses []*wire.NetAddress) ([]*wire.NetAddress, error) { +func (p *Peer) PushAddrMsg(addresses []*wire.NetAddress, subnetworkID *subnetworkid.SubnetworkID) ([]*wire.NetAddress, error) { addressCount := len(addresses) // Nothing to send. @@ -882,7 +882,7 @@ func (p *Peer) PushAddrMsg(addresses []*wire.NetAddress) ([]*wire.NetAddress, er return nil, nil } - msg := wire.NewMsgAddr() + msg := wire.NewMsgAddr(subnetworkID) msg.AddrList = make([]*wire.NetAddress, addressCount) copy(msg.AddrList, addresses) @@ -1048,7 +1048,7 @@ func (p *Peer) handleRemoteVersionMsg(msg *wire.MsgVersion) error { isLocalNodeFull := p.cfg.SubnetworkID.IsEqual(&wire.SubnetworkIDSupportsAll) isRemoteNodeFull := msg.SubnetworkID.IsEqual(&wire.SubnetworkIDSupportsAll) if (isLocalNodeFull && !isRemoteNodeFull && !p.inbound) || - (!isRemoteNodeFull && !msg.SubnetworkID.IsEqual(p.cfg.SubnetworkID)) { + (!isLocalNodeFull && !isRemoteNodeFull && !msg.SubnetworkID.IsEqual(p.cfg.SubnetworkID)) { return errors.New("incompatible subnetworks") } diff --git a/peer/peer_test.go b/peer/peer_test.go index c6f04f705..1eec3c3b3 100644 --- a/peer/peer_test.go +++ b/peer/peer_test.go @@ -477,11 +477,11 @@ func TestPeerListeners(t *testing.T) { }{ { "OnGetAddr", - wire.NewMsgGetAddr(), + wire.NewMsgGetAddr(nil), }, { "OnAddr", - wire.NewMsgAddr(), + wire.NewMsgAddr(nil), }, { "OnPing", @@ -717,7 +717,7 @@ func TestOutboundPeer(t *testing.T) { na := wire.NetAddress{} addrs = append(addrs, &na) } - if _, err := p2.PushAddrMsg(addrs); err != nil { + if _, err := p2.PushAddrMsg(addrs, nil); err != nil { t.Errorf("PushAddrMsg: unexpected err %v\n", err) return } @@ -734,7 +734,7 @@ func TestOutboundPeer(t *testing.T) { p2.PushRejectMsg("block", wire.RejectInvalid, "invalid", nil, false) // Test Queue Messages - p2.QueueMessage(wire.NewMsgGetAddr(), nil) + p2.QueueMessage(wire.NewMsgGetAddr(nil), nil) p2.QueueMessage(wire.NewMsgPing(1), nil) p2.QueueMessage(wire.NewMsgMemPool(), nil) p2.QueueMessage(wire.NewMsgGetData(), nil) diff --git a/server/p2p/p2p.go b/server/p2p/p2p.go index 34c8cdc04..161a50bb0 100644 --- a/server/p2p/p2p.go +++ b/server/p2p/p2p.go @@ -21,6 +21,8 @@ import ( "sync/atomic" "time" + "github.com/daglabs/btcd/util/subnetworkid" + "github.com/daglabs/btcd/addrmgr" "github.com/daglabs/btcd/blockdag" "github.com/daglabs/btcd/blockdag/indexers" @@ -326,7 +328,7 @@ func (sp *Peer) relayTxDisabled() bool { // pushAddrMsg sends an addr message to the connected peer using the provided // addresses. -func (sp *Peer) pushAddrMsg(addresses []*wire.NetAddress) { +func (sp *Peer) pushAddrMsg(addresses []*wire.NetAddress, subnetworkID *subnetworkid.SubnetworkID) { // Filter addresses already known to the peer. addrs := make([]*wire.NetAddress, 0, len(addresses)) for _, addr := range addresses { @@ -334,7 +336,7 @@ func (sp *Peer) pushAddrMsg(addresses []*wire.NetAddress) { addrs = append(addrs, addr) } } - known, err := sp.PushAddrMsg(addrs) + known, err := sp.PushAddrMsg(addrs, subnetworkID) if err != nil { peerLog.Errorf("Can't push address message to %s: %s", sp.Peer, err) sp.Disconnect() @@ -415,14 +417,18 @@ func (sp *Peer) OnVersion(_ *peer.Peer, msg *wire.MsgVersion) { if addrmgr.IsRoutable(lna) { // Filter addresses the peer already knows about. addresses := []*wire.NetAddress{lna} - sp.pushAddrMsg(addresses) + sp.pushAddrMsg(addresses, sp.SubnetworkID()) } } // Request known addresses if the server address manager needs // more. if addrManager.NeedMoreAddresses() { - sp.QueueMessage(wire.NewMsgGetAddr(), nil) + sp.QueueMessage(wire.NewMsgGetAddr(sp.SubnetworkID()), nil) + + if !sp.SubnetworkID().IsEqual(&wire.SubnetworkIDSupportsAll) { + sp.QueueMessage(wire.NewMsgGetAddr(&wire.SubnetworkIDSupportsAll), nil) + } } // Mark the address as a known good address. @@ -1091,10 +1097,10 @@ func (sp *Peer) OnGetAddr(_ *peer.Peer, msg *wire.MsgGetAddr) { sp.sentAddrs = true // Get the current known addresses from the address manager. - addrCache := sp.server.addrManager.AddressCache() + addrCache := sp.server.addrManager.AddressCache(msg.SubnetworkID) // Push the addresses. - sp.pushAddrMsg(addrCache) + sp.pushAddrMsg(addrCache, sp.SubnetworkID()) } // OnAddr is invoked when a peer receives an addr bitcoin message and is @@ -1893,16 +1899,22 @@ func (s *Server) peerHandler() { } if !config.MainConfig().DisableDNSSeed { - // Add peers discovered through DNS to the address manager. + seedFn := func(addrs []*wire.NetAddress) { + // Bitcoind uses a lookup of the dns seeder here. Since seeder returns + // IPs of nodes and not its own IP, we can not know real IP of + // source. So we'll take first returned address as source. + s.addrManager.AddAddresses(addrs, addrs[0]) + } + + // Add full nodes discovered through DNS to the address manager. connmgr.SeedFromDNS(config.ActiveNetParams(), defaultRequiredServices, - serverutils.BTCDLookup, func(addrs []*wire.NetAddress) { - // Bitcoind uses a lookup of the dns seeder here. This - // is rather strange since the values looked up by the - // DNS seed lookups will vary quite a lot. - // to replicate this behaviour we put all addresses as - // having come from the first one. - s.addrManager.AddAddresses(addrs, addrs[0]) - }) + &wire.SubnetworkIDSupportsAll, serverutils.BTCDLookup, seedFn) + + if !config.MainConfig().SubnetworkID.IsEqual(&wire.SubnetworkIDSupportsAll) { + // Node is partial - fetch nodes with same subnetwork + connmgr.SeedFromDNS(config.ActiveNetParams(), defaultRequiredServices, + config.MainConfig().SubnetworkID, serverutils.BTCDLookup, seedFn) + } } go s.connManager.Start() @@ -2217,7 +2229,14 @@ func ParseListeners(addrs []string) ([]net.Addr, error) { // Parse the IP. ip := net.ParseIP(host) if ip == nil { - return nil, fmt.Errorf("'%s' is not a valid IP address", host) + hostAddrs, err := net.LookupHost(host) + if err != nil { + return nil, err + } + ip = net.ParseIP(hostAddrs[0]) + if ip == nil { + return nil, fmt.Errorf("Cannot resolve IP address for host '%s'", host) + } } // To4 returns nil when the IP is not an IPv4 address, so use diff --git a/wire/bench_test.go b/wire/bench_test.go index 83fbb1160..0b181da2d 100644 --- a/wire/bench_test.go +++ b/wire/bench_test.go @@ -484,7 +484,7 @@ func BenchmarkDecodeAddr(b *testing.B) { // Create a message with the maximum number of addresses. pver := ProtocolVersion ip := net.ParseIP("127.0.0.1") - ma := NewMsgAddr() + ma := NewMsgAddr(nil) for port := uint16(0); port < MaxAddrPerMsg; port++ { ma.AddAddress(NewNetAddressIPPort(ip, port, SFNodeNetwork)) } diff --git a/wire/common.go b/wire/common.go index e84b381d6..939a36275 100644 --- a/wire/common.go +++ b/wire/common.go @@ -13,6 +13,7 @@ import ( "time" "github.com/daglabs/btcd/dagconfig/daghash" + "github.com/daglabs/btcd/util/subnetworkid" ) const ( @@ -271,6 +272,13 @@ func readElement(r io.Reader, element interface{}) error { } return nil + case *subnetworkid.SubnetworkID: + _, err := io.ReadFull(r, e[:]) + if err != nil { + return err + } + return nil + case *ServiceFlag: rv, err := binarySerializer.Uint64(r, littleEndian) if err != nil { @@ -405,6 +413,13 @@ func writeElement(w io.Writer, element interface{}) error { } return nil + case *subnetworkid.SubnetworkID: + _, err := w.Write(e[:]) + if err != nil { + return err + } + return nil + case ServiceFlag: err := binarySerializer.PutUint64(w, littleEndian, uint64(e)) if err != nil { diff --git a/wire/message_test.go b/wire/message_test.go index 7d17429aa..61550777f 100644 --- a/wire/message_test.go +++ b/wire/message_test.go @@ -49,8 +49,8 @@ func TestMessage(t *testing.T) { msgVersion := NewMsgVersion(me, you, 123123, 0, &SubnetworkIDSupportsAll) msgVerack := NewMsgVerAck() - msgGetAddr := NewMsgGetAddr() - msgAddr := NewMsgAddr() + msgGetAddr := NewMsgGetAddr(nil) + msgAddr := NewMsgAddr(nil) msgGetBlocks := NewMsgGetBlocks(&daghash.Hash{}) msgBlock := &blockOne msgInv := NewMsgInv() @@ -88,8 +88,8 @@ func TestMessage(t *testing.T) { }{ {msgVersion, msgVersion, pver, MainNet, 145}, {msgVerack, msgVerack, pver, MainNet, 24}, - {msgGetAddr, msgGetAddr, pver, MainNet, 24}, - {msgAddr, msgAddr, pver, MainNet, 25}, + {msgGetAddr, msgGetAddr, pver, MainNet, 25}, + {msgAddr, msgAddr, pver, MainNet, 26}, {msgGetBlocks, msgGetBlocks, pver, MainNet, 61}, {msgBlock, msgBlock, pver, MainNet, 340}, {msgInv, msgInv, pver, MainNet, 25}, @@ -221,7 +221,7 @@ func TestReadMessageWireErrors(t *testing.T) { // Wire encoded bytes for a message which exceeds the max payload for // a specific message type. - exceedTypePayloadBytes := makeHeader(btcnet, "getaddr", 1, 0) + exceedTypePayloadBytes := makeHeader(btcnet, "getaddr", 22, 0) // Wire encoded bytes for a message which does not deliver the full // payload according to the header length. diff --git a/wire/msgaddr.go b/wire/msgaddr.go index 566d7c5a5..0cd24f9fe 100644 --- a/wire/msgaddr.go +++ b/wire/msgaddr.go @@ -7,6 +7,8 @@ package wire import ( "fmt" "io" + + "github.com/daglabs/btcd/util/subnetworkid" ) // MaxAddrPerMsg is the maximum number of addresses that can be in a single @@ -24,7 +26,8 @@ const MaxAddrPerMsg = 1000 // Use the AddAddress function to build up the list of known addresses when // sending an addr message to another peer. type MsgAddr struct { - AddrList []*NetAddress + SubnetworkID *subnetworkid.SubnetworkID + AddrList []*NetAddress } // AddAddress adds a known active peer to the message. @@ -58,6 +61,24 @@ func (msg *MsgAddr) ClearAddresses() { // BtcDecode decodes r using the bitcoin protocol encoding into the receiver. // This is part of the Message interface implementation. func (msg *MsgAddr) BtcDecode(r io.Reader, pver uint32) error { + // Read subnetwork + var isAllSubnetworks bool + err := readElement(r, &isAllSubnetworks) + if err != nil { + return err + } + if isAllSubnetworks { + msg.SubnetworkID = nil + } else { + var subnetworkID subnetworkid.SubnetworkID + err = readElement(r, &subnetworkID) + if err != nil { + return err + } + msg.SubnetworkID = &subnetworkID + } + + // Read addresses array count, err := ReadVarInt(r, pver) if err != nil { return err @@ -74,7 +95,7 @@ func (msg *MsgAddr) BtcDecode(r io.Reader, pver uint32) error { msg.AddrList = make([]*NetAddress, 0, count) for i := uint64(0); i < count; i++ { na := &addrList[i] - err := readNetAddress(r, pver, na, true) + err = readNetAddress(r, pver, na, true) if err != nil { return err } @@ -86,8 +107,6 @@ func (msg *MsgAddr) BtcDecode(r io.Reader, pver uint32) error { // BtcEncode encodes the receiver to w using the bitcoin protocol encoding. // This is part of the Message interface implementation. func (msg *MsgAddr) BtcEncode(w io.Writer, pver uint32) error { - // Protocol versions before MultipleAddressVersion only allowed 1 address - // per message. count := len(msg.AddrList) if count > MaxAddrPerMsg { str := fmt.Sprintf("too many addresses for message "+ @@ -95,7 +114,20 @@ func (msg *MsgAddr) BtcEncode(w io.Writer, pver uint32) error { return messageError("MsgAddr.BtcEncode", str) } - err := WriteVarInt(w, pver, uint64(count)) + // Write subnetwork ID + isAllSubnetworks := msg.SubnetworkID == nil + err := writeElement(w, isAllSubnetworks) + if err != nil { + return err + } + if !isAllSubnetworks { + err = writeElement(w, msg.SubnetworkID) + if err != nil { + return err + } + } + + err = WriteVarInt(w, pver, uint64(count)) if err != nil { return err } @@ -119,14 +151,15 @@ func (msg *MsgAddr) Command() string { // MaxPayloadLength returns the maximum length the payload can be for the // receiver. This is part of the Message interface implementation. func (msg *MsgAddr) MaxPayloadLength(pver uint32) uint32 { - // Num addresses (varInt) + max allowed addresses. - return MaxVarIntPayload + (MaxAddrPerMsg * maxNetAddressPayload(pver)) + // IsAllSubnetworks flag 1 byte + SubnetworkID length + Num addresses (varInt) + max allowed addresses. + return 1 + subnetworkid.IDLength + MaxVarIntPayload + (MaxAddrPerMsg * maxNetAddressPayload(pver)) } // NewMsgAddr returns a new bitcoin addr message that conforms to the // Message interface. See MsgAddr for details. -func NewMsgAddr() *MsgAddr { +func NewMsgAddr(subnetworkID *subnetworkid.SubnetworkID) *MsgAddr { return &MsgAddr{ - AddrList: make([]*NetAddress, 0, MaxAddrPerMsg), + SubnetworkID: subnetworkID, + AddrList: make([]*NetAddress, 0, MaxAddrPerMsg), } } diff --git a/wire/msgaddr_test.go b/wire/msgaddr_test.go index d8db3a3b2..a5df7a45c 100644 --- a/wire/msgaddr_test.go +++ b/wire/msgaddr_test.go @@ -21,7 +21,7 @@ func TestAddr(t *testing.T) { // Ensure the command is expected value. wantCmd := "addr" - msg := NewMsgAddr() + msg := NewMsgAddr(nil) if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgAddr: wrong command - got %v want %v", cmd, wantCmd) @@ -29,7 +29,7 @@ func TestAddr(t *testing.T) { // Ensure max payload is expected value for latest protocol version. // Num addresses (varInt) + max allowed addresses. - wantPayload := uint32(34009) + wantPayload := uint32(34030) maxPayload := msg.MaxPayloadLength(pver) if maxPayload != wantPayload { t.Errorf("MaxPayloadLength: wrong max payload length for "+ @@ -91,15 +91,17 @@ func TestAddrWire(t *testing.T) { } // Empty address message. - noAddr := NewMsgAddr() + noAddr := NewMsgAddr(nil) noAddrEncoded := []byte{ + 0x01, // All subnetworks 0x00, // Varint for number of addresses } // Address message with multiple addresses. - multiAddr := NewMsgAddr() + multiAddr := NewMsgAddr(nil) multiAddr.AddAddresses(na, na2) multiAddrEncoded := []byte{ + 0x01, // All subnetworks 0x02, // Varint for number of addresses 0x29, 0xab, 0x5f, 0x49, 0x00, 0x00, 0x00, 0x00, // Timestamp 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // SFNodeNetwork @@ -111,7 +113,27 @@ func TestAddrWire(t *testing.T) { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xc0, 0xa8, 0x00, 0x01, // IP 192.168.0.1 0x20, 0x8e, // Port 8334 in big-endian + } + // Address message with multiple addresses and subnetworkID. + multiAddrSubnet := NewMsgAddr(&SubnetworkIDNative) + multiAddrSubnet.AddAddresses(na, na2) + multiAddrSubnetEncoded := []byte{ + 0x00, // All subnetworks + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Subnetwork ID + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x02, // Varint for number of addresses + 0x29, 0xab, 0x5f, 0x49, 0x00, 0x00, 0x00, 0x00, // Timestamp + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // SFNodeNetwork + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0xff, 0xff, 0x7f, 0x00, 0x00, 0x01, // IP 127.0.0.1 + 0x20, 0x8d, // Port 8333 in big-endian + 0x29, 0xab, 0x5f, 0x49, 0x00, 0x00, 0x00, 0x00, // Timestamp + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // SFNodeNetwork + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0xff, 0xff, 0xc0, 0xa8, 0x00, 0x01, // IP 192.168.0.1 + 0x20, 0x8e, // Port 8334 in big-endian } tests := []struct { @@ -135,6 +157,14 @@ func TestAddrWire(t *testing.T) { multiAddrEncoded, ProtocolVersion, }, + + // Latest protocol version with multiple addresses and subnetwork. + { + multiAddrSubnet, + multiAddrSubnet, + multiAddrSubnetEncoded, + ProtocolVersion, + }, } t.Logf("Running %d tests", len(tests)) @@ -189,9 +219,10 @@ func TestAddrWireErrors(t *testing.T) { } // Address message with multiple addresses. - baseAddr := NewMsgAddr() + baseAddr := NewMsgAddr(nil) baseAddr.AddAddresses(na, na2) baseAddrEncoded := []byte{ + 0x01, // All subnetworks 0x02, // Varint for number of addresses 0x29, 0xab, 0x5f, 0x49, 0x00, 0x00, 0x00, 0x00, // Timestamp 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // SFNodeNetwork @@ -203,17 +234,17 @@ func TestAddrWireErrors(t *testing.T) { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xc0, 0xa8, 0x00, 0x01, // IP 192.168.0.1 0x20, 0x8e, // Port 8334 in big-endian - } // Message that forces an error by having more than the max allowed // addresses. - maxAddr := NewMsgAddr() + maxAddr := NewMsgAddr(nil) for i := 0; i < MaxAddrPerMsg; i++ { maxAddr.AddAddress(na) } maxAddr.AddrList = append(maxAddr.AddrList, na) maxAddrEncoded := []byte{ + 0x01, // All subnetworks 0xfd, 0x03, 0xe9, // Varint for number of addresses (1001) } @@ -227,11 +258,11 @@ func TestAddrWireErrors(t *testing.T) { }{ // Latest protocol version with intentional read/write errors. // Force error in addresses count - {baseAddr, baseAddrEncoded, pver, 0, io.ErrShortWrite, io.EOF}, - // Force error in address list. {baseAddr, baseAddrEncoded, pver, 1, io.ErrShortWrite, io.EOF}, + // Force error in address list. + {baseAddr, baseAddrEncoded, pver, 2, io.ErrShortWrite, io.EOF}, // Force error with greater than max inventory vectors. - {maxAddr, maxAddrEncoded, pver, 3, wireErr, wireErr}, + {maxAddr, maxAddrEncoded, pver, 4, wireErr, wireErr}, } t.Logf("Running %d tests", len(tests)) diff --git a/wire/msggetaddr.go b/wire/msggetaddr.go index 0a4bf57ad..b4a71a9c6 100644 --- a/wire/msggetaddr.go +++ b/wire/msggetaddr.go @@ -6,6 +6,8 @@ package wire import ( "io" + + "github.com/daglabs/btcd/util/subnetworkid" ) // MsgGetAddr implements the Message interface and represents a bitcoin @@ -14,17 +16,46 @@ import ( // via one or more addr messages (MsgAddr). // // This message has no payload. -type MsgGetAddr struct{} +type MsgGetAddr struct { + SubnetworkID *subnetworkid.SubnetworkID +} // BtcDecode decodes r using the bitcoin protocol encoding into the receiver. // This is part of the Message interface implementation. func (msg *MsgGetAddr) BtcDecode(r io.Reader, pver uint32) error { + var isAllSubnetworks bool + err := readElement(r, &isAllSubnetworks) + if err != nil { + return err + } + if isAllSubnetworks { + msg.SubnetworkID = nil + } else { + var subnetworkID subnetworkid.SubnetworkID + err = readElement(r, &subnetworkID) + if err != nil { + return err + } + msg.SubnetworkID = &subnetworkID + } + return nil } // BtcEncode encodes the receiver to w using the bitcoin protocol encoding. // This is part of the Message interface implementation. func (msg *MsgGetAddr) BtcEncode(w io.Writer, pver uint32) error { + isAllSubnetworks := msg.SubnetworkID == nil + err := writeElement(w, isAllSubnetworks) + if err != nil { + return err + } + if !isAllSubnetworks { + err = writeElement(w, msg.SubnetworkID) + if err != nil { + return err + } + } return nil } @@ -37,11 +68,13 @@ func (msg *MsgGetAddr) Command() string { // MaxPayloadLength returns the maximum length the payload can be for the // receiver. This is part of the Message interface implementation. func (msg *MsgGetAddr) MaxPayloadLength(pver uint32) uint32 { - return 0 + return subnetworkid.IDLength + 1 } // NewMsgGetAddr returns a new bitcoin getaddr message that conforms to the // Message interface. See MsgGetAddr for details. -func NewMsgGetAddr() *MsgGetAddr { - return &MsgGetAddr{} +func NewMsgGetAddr(subnetworkID *subnetworkid.SubnetworkID) *MsgGetAddr { + return &MsgGetAddr{ + SubnetworkID: subnetworkID, + } } diff --git a/wire/msggetaddr_test.go b/wire/msggetaddr_test.go index 04586b5e1..1f17ac2fc 100644 --- a/wire/msggetaddr_test.go +++ b/wire/msggetaddr_test.go @@ -18,7 +18,7 @@ func TestGetAddr(t *testing.T) { // Ensure the command is expected value. wantCmd := "getaddr" - msg := NewMsgGetAddr() + msg := NewMsgGetAddr(nil) if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgGetAddr: wrong command - got %v want %v", cmd, wantCmd) @@ -26,7 +26,7 @@ func TestGetAddr(t *testing.T) { // Ensure max payload is expected value for latest protocol version. // Num addresses (varInt) + max allowed addresses. - wantPayload := uint32(0) + wantPayload := uint32(21) maxPayload := msg.MaxPayloadLength(pver) if maxPayload != wantPayload { t.Errorf("MaxPayloadLength: wrong max payload length for "+ @@ -38,8 +38,20 @@ func TestGetAddr(t *testing.T) { // TestGetAddrWire tests the MsgGetAddr wire encode and decode for various // protocol versions. func TestGetAddrWire(t *testing.T) { - msgGetAddr := NewMsgGetAddr() - msgGetAddrEncoded := []byte{} + // With all subnetworks + msgGetAddr := NewMsgGetAddr(nil) + msgGetAddrEncoded := []byte{ + 0x01, // All subnetworks + } + + // With specific subnetwork + msgGetAddrSubnet := NewMsgGetAddr(&SubnetworkIDNative) + msgGetAddrSubnetEncoded := []byte{ + 0x00, // All subnetworks + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Subnetwork ID + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + } tests := []struct { in *MsgGetAddr // Message to encode @@ -47,13 +59,20 @@ func TestGetAddrWire(t *testing.T) { buf []byte // Wire encoding pver uint32 // Protocol version for wire encoding }{ - // Latest protocol version. + // Latest protocol version. All subnetworks { msgGetAddr, msgGetAddr, msgGetAddrEncoded, ProtocolVersion, }, + // Latest protocol version. Specific subnetwork + { + msgGetAddrSubnet, + msgGetAddrSubnet, + msgGetAddrSubnetEncoded, + ProtocolVersion, + }, } t.Logf("Running %d tests", len(tests))