diff --git a/netadapter/router/router.go b/netadapter/router/router.go index cdcef27a8..72e19b389 100644 --- a/netadapter/router/router.go +++ b/netadapter/router/router.go @@ -12,7 +12,7 @@ type OnRouteCapacityReachedHandler func() // Router routes messages by type to their respective // input channels type Router struct { - incomingRoutes map[string]*Route + incomingRoutes map[wire.MessageCommand]*Route outgoingRoute *Route onRouteCapacityReachedHandler OnRouteCapacityReachedHandler @@ -21,7 +21,7 @@ type Router struct { // NewRouter creates a new empty router func NewRouter() *Router { router := Router{ - incomingRoutes: make(map[string]*Route), + incomingRoutes: make(map[wire.MessageCommand]*Route), outgoingRoute: NewRoute(), } router.outgoingRoute.setOnCapacityReachedHandler(func() { @@ -38,7 +38,7 @@ func (r *Router) SetOnRouteCapacityReachedHandler(onRouteCapacityReachedHandler // AddIncomingRoute registers the messages of types `messageTypes` to // be routed to the given `route` -func (r *Router) AddIncomingRoute(messageTypes []string) (*Route, error) { +func (r *Router) AddIncomingRoute(messageTypes []wire.MessageCommand) (*Route, error) { route := NewRoute() for _, messageType := range messageTypes { if _, ok := r.incomingRoutes[messageType]; ok { @@ -54,7 +54,7 @@ func (r *Router) AddIncomingRoute(messageTypes []string) (*Route, error) { // RemoveRoute unregisters the messages of types `messageTypes` from // the router -func (r *Router) RemoveRoute(messageTypes []string) error { +func (r *Router) RemoveRoute(messageTypes []wire.MessageCommand) error { for _, messageType := range messageTypes { if _, ok := r.incomingRoutes[messageType]; !ok { return errors.Errorf("a route for '%s' does not exist", messageType) diff --git a/netadapter/server/grpcserver/protowire/messages.pb.go b/netadapter/server/grpcserver/protowire/messages.pb.go index f4c8416b4..1e5a31b72 100644 --- a/netadapter/server/grpcserver/protowire/messages.pb.go +++ b/netadapter/server/grpcserver/protowire/messages.pb.go @@ -30,7 +30,7 @@ type KaspadMessage struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Command string `protobuf:"bytes,1,opt,name=command,proto3" json:"command,omitempty"` + Command uint32 `protobuf:"varint,1,opt,name=command,proto3" json:"command,omitempty"` Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"` } @@ -66,11 +66,11 @@ func (*KaspadMessage) Descriptor() ([]byte, []int) { return file_messages_proto_rawDescGZIP(), []int{0} } -func (x *KaspadMessage) GetCommand() string { +func (x *KaspadMessage) GetCommand() uint32 { if x != nil { return x.Command } - return "" + return 0 } func (x *KaspadMessage) GetPayload() []byte { @@ -86,7 +86,7 @@ var file_messages_proto_rawDesc = []byte{ 0x0a, 0x0e, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x09, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x77, 0x69, 0x72, 0x65, 0x22, 0x43, 0x0a, 0x0d, 0x4b, 0x61, 0x73, 0x70, 0x61, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x18, 0x0a, 0x07, - 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x63, + 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x07, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x32, 0x50, 0x0a, 0x03, 0x50, 0x32, 0x50, 0x12, 0x49, 0x0a, 0x0d, 0x4d, 0x65, 0x73, 0x73, 0x61, diff --git a/netadapter/server/grpcserver/protowire/messages.proto b/netadapter/server/grpcserver/protowire/messages.proto index 9ba2a31c9..dd9dbb59e 100644 --- a/netadapter/server/grpcserver/protowire/messages.proto +++ b/netadapter/server/grpcserver/protowire/messages.proto @@ -4,10 +4,10 @@ package protowire; option go_package = "github.com/kaspanet/kaspad/protowire"; message KaspadMessage{ - string command = 1; - bytes payload = 2; + uint32 command = 1; + bytes payload = 2; } service P2P { - rpc MessageStream (stream KaspadMessage) returns (stream KaspadMessage) {} + rpc MessageStream (stream KaspadMessage) returns (stream KaspadMessage) {} } diff --git a/netadapter/server/grpcserver/protowire/wire.go b/netadapter/server/grpcserver/protowire/wire.go index 541446720..6a287f778 100644 --- a/netadapter/server/grpcserver/protowire/wire.go +++ b/netadapter/server/grpcserver/protowire/wire.go @@ -8,7 +8,7 @@ import ( // ToWireMessage converts a KaspadMessage to its wire.Message representation func (x *KaspadMessage) ToWireMessage() (wire.Message, error) { - message, err := wire.MakeEmptyMessage(x.Command) + message, err := wire.MakeEmptyMessage(wire.MessageCommand(x.Command)) if err != nil { return nil, err } @@ -32,7 +32,7 @@ func FromWireMessage(message wire.Message) (*KaspadMessage, error) { } return &KaspadMessage{ - Command: message.Command(), + Command: uint32(message.Command()), Payload: payloadWriter.Bytes(), }, nil } diff --git a/peer/message_logging.go b/peer/message_logging.go index c92937d1d..84cd2b536 100644 --- a/peer/message_logging.go +++ b/peer/message_logging.go @@ -150,11 +150,10 @@ func messageSummary(msg wire.Message) string { // characters which are even remotely dangerous such as HTML // control characters, etc. Also limit them to sane length for // logging. - rejCommand := sanitizeString(msg.Cmd, wire.CommandSize) rejReason := sanitizeString(msg.Reason, maxRejectReasonLen) - summary := fmt.Sprintf("cmd %s, code %s, reason %s", rejCommand, + summary := fmt.Sprintf("cmd %s, code %s, reason %s", msg.Cmd, msg.Code, rejReason) - if rejCommand == wire.CmdBlock || rejCommand == wire.CmdTx { + if msg.Cmd == wire.CmdBlock || msg.Cmd == wire.CmdTx { summary += fmt.Sprintf(", hash %s", msg.Hash) } return summary diff --git a/peer/peer.go b/peer/peer.go index df7cff53d..1550ae52e 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -657,7 +657,7 @@ func (p *Peer) AddBanScore(persistent, transient uint32, reason string) { // AddBanScoreAndPushRejectMsg increases ban score and sends a // reject message to the misbehaving peer. -func (p *Peer) AddBanScoreAndPushRejectMsg(command string, code wire.RejectCode, hash *daghash.Hash, persistent, transient uint32, reason string) { +func (p *Peer) AddBanScoreAndPushRejectMsg(command wire.MessageCommand, code wire.RejectCode, hash *daghash.Hash, persistent, transient uint32, reason string) { p.PushRejectMsg(command, code, reason, hash, true) p.cfg.AddBanScore(persistent, transient, reason) } @@ -837,7 +837,7 @@ func (p *Peer) PushBlockLocatorMsg(locator blockdag.BlockLocator) error { // function to block until the reject message has actually been sent. // // This function is safe for concurrent access. -func (p *Peer) PushRejectMsg(command string, code wire.RejectCode, reason string, hash *daghash.Hash, wait bool) { +func (p *Peer) PushRejectMsg(command wire.MessageCommand, code wire.RejectCode, reason string, hash *daghash.Hash, wait bool) { msg := wire.NewMsgReject(command, code, reason) if command == wire.CmdTx || command == wire.CmdBlock { if hash == nil { @@ -1094,7 +1094,7 @@ func (p *Peer) shouldHandleReadError(err error) bool { // maybeAddDeadline potentially adds a deadline for the appropriate expected // response for the passed wire protocol command to the pending responses map. -func (p *Peer) maybeAddDeadline(pendingResponses map[string]time.Time, msgCmd string) { +func (p *Peer) maybeAddDeadline(pendingResponses map[wire.MessageCommand]time.Time, msgCmd wire.MessageCommand) { // Setup a deadline for each message being sent that expects a response. // // NOTE: Pings are intentionally ignored here since they are typically @@ -1138,7 +1138,7 @@ func (p *Peer) stallHandler() { var deadlineOffset time.Duration // pendingResponses tracks the expected response deadline times. - pendingResponses := make(map[string]time.Time) + pendingResponses := make(map[wire.MessageCommand]time.Time) // stallTicker is used to periodically check pending responses that have // exceeded the expected deadline and disconnect the peer due to @@ -1313,7 +1313,7 @@ out: // at least that much of the message was valid, but that is not // currently exposed by wire, so just used malformed for the // command. - p.AddBanScoreAndPushRejectMsg("malformed", wire.RejectMalformed, nil, + p.AddBanScoreAndPushRejectMsg(wire.CmdRejectMalformed, wire.RejectMalformed, nil, BanScoreMalformedMessage, 0, errMsg) } break out diff --git a/protocol/handshake.go b/protocol/handshake.go index e56a13792..bf9a55fcc 100644 --- a/protocol/handshake.go +++ b/protocol/handshake.go @@ -18,12 +18,12 @@ import ( func handshake(router *routerpkg.Router, netAdapter *netadapter.NetAdapter, peer *peerpkg.Peer, dag *blockdag.BlockDAG, addressManager *addrmgr.AddrManager) (closed bool, err error) { - receiveVersionRoute, err := router.AddIncomingRoute([]string{wire.CmdVersion}) + receiveVersionRoute, err := router.AddIncomingRoute([]wire.MessageCommand{wire.CmdVersion}) if err != nil { panic(err) } - sendVersionRoute, err := router.AddIncomingRoute([]string{wire.CmdVerAck}) + sendVersionRoute, err := router.AddIncomingRoute([]wire.MessageCommand{wire.CmdVerAck}) if err != nil { panic(err) } @@ -102,7 +102,7 @@ func handshake(router *routerpkg.Router, netAdapter *netadapter.NetAdapter, peer addressManager.AddAddress(peerAddress, peerAddress, subnetworkID) } - err = router.RemoveRoute([]string{wire.CmdVersion, wire.CmdVerAck}) + err = router.RemoveRoute([]wire.MessageCommand{wire.CmdVersion, wire.CmdVerAck}) if err != nil { panic(err) } diff --git a/protocol/protocol.go b/protocol/protocol.go index 3ab060bcd..bfdac2f6c 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -99,37 +99,37 @@ func startFlows(netAdapter *netadapter.NetAdapter, router *routerpkg.Router, dag return nil } - addOneTimeFlow("SendAddresses", router, []string{wire.CmdGetAddresses}, &stopped, stop, + addOneTimeFlow("SendAddresses", router, []wire.MessageCommand{wire.CmdGetAddresses}, &stopped, stop, func(incomingRoute *routerpkg.Route) (routeClosed bool, err error) { return sendaddresses.SendAddresses(incomingRoute, outgoingRoute, addressManager) }, ) - addOneTimeFlow("ReceiveAddresses", router, []string{wire.CmdAddress}, &stopped, stop, + addOneTimeFlow("ReceiveAddresses", router, []wire.MessageCommand{wire.CmdAddress}, &stopped, stop, func(incomingRoute *routerpkg.Route) (routeClosed bool, err error) { return receiveaddresses.ReceiveAddresses(incomingRoute, outgoingRoute, peer, addressManager) }, ) - addFlow("HandleRelayInvs", router, []string{wire.CmdInvRelayBlock, wire.CmdBlock}, &stopped, stop, + addFlow("HandleRelayInvs", router, []wire.MessageCommand{wire.CmdInvRelayBlock, wire.CmdBlock}, &stopped, stop, func(incomingRoute *routerpkg.Route) error { return handlerelayinvs.HandleRelayInvs(incomingRoute, outgoingRoute, peer, netAdapter, dag) }, ) - addFlow("HandleRelayBlockRequests", router, []string{wire.CmdGetRelayBlocks}, &stopped, stop, + addFlow("HandleRelayBlockRequests", router, []wire.MessageCommand{wire.CmdGetRelayBlocks}, &stopped, stop, func(incomingRoute *routerpkg.Route) error { return handlerelayblockrequests.HandleRelayBlockRequests(incomingRoute, outgoingRoute, peer, dag) }, ) - addFlow("ReceivePings", router, []string{wire.CmdPing}, &stopped, stop, + addFlow("ReceivePings", router, []wire.MessageCommand{wire.CmdPing}, &stopped, stop, func(incomingRoute *routerpkg.Route) error { return ping.ReceivePings(incomingRoute, outgoingRoute) }, ) - addFlow("SendPings", router, []string{wire.CmdPong}, &stopped, stop, + addFlow("SendPings", router, []wire.MessageCommand{wire.CmdPong}, &stopped, stop, func(incomingRoute *routerpkg.Route) error { return ping.SendPings(incomingRoute, outgoingRoute, peer) }, @@ -139,7 +139,7 @@ func startFlows(netAdapter *netadapter.NetAdapter, router *routerpkg.Router, dag return err } -func addFlow(name string, router *routerpkg.Router, messageTypes []string, stopped *uint32, +func addFlow(name string, router *routerpkg.Router, messageTypes []wire.MessageCommand, stopped *uint32, stopChan chan error, flow func(route *routerpkg.Route) error) { route, err := router.AddIncomingRoute(messageTypes) @@ -158,7 +158,7 @@ func addFlow(name string, router *routerpkg.Router, messageTypes []string, stopp }) } -func addOneTimeFlow(name string, router *routerpkg.Router, messageTypes []string, stopped *uint32, +func addOneTimeFlow(name string, router *routerpkg.Router, messageTypes []wire.MessageCommand, stopped *uint32, stopChan chan error, flow func(route *routerpkg.Route) (routeClosed bool, err error)) { route, err := router.AddIncomingRoute(messageTypes) diff --git a/server/p2p/p2p.go b/server/p2p/p2p.go index b71cb5d3e..8f13c733c 100644 --- a/server/p2p/p2p.go +++ b/server/p2p/p2p.go @@ -367,7 +367,7 @@ func (sp *Peer) addBanScore(persistent, transient uint32, reason string) { // allow bloom filters. Additionally, if the peer has negotiated to a protocol // version that is high enough to observe the bloom filter service support bit, // it will be banned since it is intentionally violating the protocol. -func (sp *Peer) enforceNodeBloomFlag(cmd string) bool { +func (sp *Peer) enforceNodeBloomFlag(cmd wire.MessageCommand) bool { if sp.server.services&wire.SFNodeBloom != wire.SFNodeBloom { // NOTE: Even though the addBanScore function already examines // whether or not banning is enabled, it is checked here as well @@ -377,7 +377,7 @@ func (sp *Peer) enforceNodeBloomFlag(cmd string) bool { // Disconnect the peer regardless of whether it was // banned. - sp.addBanScore(peer.BanScoreNodeBloomFlagViolation, 0, cmd) + sp.addBanScore(peer.BanScoreNodeBloomFlagViolation, 0, cmd.String()) sp.Disconnect() return false } diff --git a/wire/common.go b/wire/common.go index 20fc63c7a..ff1b7e7c1 100644 --- a/wire/common.go +++ b/wire/common.go @@ -12,6 +12,7 @@ import ( "github.com/kaspanet/kaspad/util/daghash" "github.com/kaspanet/kaspad/util/mstime" "github.com/kaspanet/kaspad/util/subnetworkid" + "github.com/pkg/errors" "io" "math" ) @@ -34,6 +35,9 @@ var ( var errNonCanonicalVarInt = "non-canonical varint %x - discriminant %x must " + "encode a value greater than %x" +// errNoEncodingForType signifies that there's no encoding for the given type. +var errNoEncodingForType = errors.New("there's no encoding for this type") + // int64Time represents a unix timestamp with milliseconds precision encoded with // an int64. It is used as a way to signal the readElement function how to decode // a timestamp into a Go mstime.Time since it is otherwise ambiguous. @@ -77,6 +81,14 @@ func ReadElement(r io.Reader, element interface{}) error { *e = rv return nil + case *uint8: + rv, err := binaryserializer.Uint8(r) + if err != nil { + return err + } + *e = rv + return nil + case *bool: rv, err := binaryserializer.Uint8(r) if err != nil { @@ -107,11 +119,12 @@ func ReadElement(r io.Reader, element interface{}) error { return nil // Message header command. - case *[CommandSize]uint8: - _, err := io.ReadFull(r, e[:]) + case *MessageCommand: + rv, err := binaryserializer.Uint32(r, littleEndian) if err != nil { return err } + *e = MessageCommand(rv) return nil // IP address. @@ -180,9 +193,7 @@ func ReadElement(r io.Reader, element interface{}) error { return nil } - // Fall back to the slower binary.Read if a fast path was not available - // above. - return binary.Read(r, littleEndian, element) + return errors.Wrapf(errNoEncodingForType, "couldn't find a way to read type %T", element) } // readElements reads multiple items from r. It is equivalent to multiple @@ -230,6 +241,13 @@ func WriteElement(w io.Writer, element interface{}) error { } return nil + case uint8: + err := binaryserializer.PutUint8(w, e) + if err != nil { + return err + } + return nil + case bool: var err error if e { @@ -251,8 +269,8 @@ func WriteElement(w io.Writer, element interface{}) error { return nil // Message header command. - case [CommandSize]uint8: - _, err := w.Write(e[:]) + case MessageCommand: + err := binaryserializer.PutUint32(w, littleEndian, uint32(e)) if err != nil { return err } @@ -319,9 +337,7 @@ func WriteElement(w io.Writer, element interface{}) error { return nil } - // Fall back to the slower binary.Write if a fast path was not available - // above. - return binary.Write(w, littleEndian, element) + return errors.Wrapf(errNoEncodingForType, "couldn't find a way to write type %T", element) } // writeElements writes multiple items to w. It is equivalent to multiple diff --git a/wire/common_test.go b/wire/common_test.go index 9594bd8f0..a5a40fc27 100644 --- a/wire/common_test.go +++ b/wire/common_test.go @@ -61,8 +61,6 @@ var exampleUTXOCommitment = &daghash.Hash{ // is mainly to test the "fast" paths in readElement and writeElement which use // type assertions to avoid reflection when possible. func TestElementWire(t *testing.T) { - type writeElementReflect int32 - tests := []struct { in interface{} // Value to encode buf []byte // Wire encoding @@ -90,13 +88,9 @@ func TestElementWire(t *testing.T) { []byte{0x01, 0x02, 0x03, 0x04}, }, { - [CommandSize]byte{ - 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, - 0x09, 0x0a, 0x0b, 0x0c, - }, + MessageCommand(0x10), []byte{ - 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, - 0x09, 0x0a, 0x0b, 0x0c, + 0x10, 0x00, 0x00, 0x00, }, }, { @@ -135,11 +129,6 @@ func TestElementWire(t *testing.T) { KaspaNet(Mainnet), []byte{0x1d, 0xf7, 0xdc, 0x3d}, }, - // Type not supported by the "fast" path and requires reflection. - { - writeElementReflect(1), - []byte{0x01, 0x00, 0x00, 0x00}, - }, } t.Logf("Running %d tests", len(tests)) @@ -183,6 +172,8 @@ func TestElementWire(t *testing.T) { // TestElementWireErrors performs negative tests against wire encode and decode // of various element types to confirm error paths work correctly. func TestElementWireErrors(t *testing.T) { + type writeElementReflect int32 + tests := []struct { in interface{} // Value to encode max int // Max size of fixed buffer to induce errors @@ -195,10 +186,7 @@ func TestElementWireErrors(t *testing.T) { {true, 0, io.ErrShortWrite, io.EOF}, {[4]byte{0x01, 0x02, 0x03, 0x04}, 0, io.ErrShortWrite, io.EOF}, { - [CommandSize]byte{ - 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, - 0x09, 0x0a, 0x0b, 0x0c, - }, + MessageCommand(10), 0, io.ErrShortWrite, io.EOF, }, { @@ -220,6 +208,8 @@ func TestElementWireErrors(t *testing.T) { {ServiceFlag(SFNodeNetwork), 0, io.ErrShortWrite, io.EOF}, {InvType(InvTypeTx), 0, io.ErrShortWrite, io.EOF}, {KaspaNet(Mainnet), 0, io.ErrShortWrite, io.EOF}, + // Type with no supported encoding. + {writeElementReflect(0), 0, errNoEncodingForType, errNoEncodingForType}, } t.Logf("Running %d tests", len(tests)) diff --git a/wire/fakemessage_test.go b/wire/fakemessage_test.go index a5d074713..94a8f2839 100644 --- a/wire/fakemessage_test.go +++ b/wire/fakemessage_test.go @@ -9,7 +9,7 @@ import "io" // fakeMessage implements the Message interface and is used to force encode // errors in messages. type fakeMessage struct { - command string + command MessageCommand payload []byte forceEncodeErr bool forceLenErr bool @@ -39,7 +39,7 @@ func (msg *fakeMessage) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the command field of the fake message and satisfies the // Message interface. -func (msg *fakeMessage) Command() string { +func (msg *fakeMessage) Command() MessageCommand { return msg.command } diff --git a/wire/message.go b/wire/message.go index 96c5ab843..beb15da9a 100644 --- a/wire/message.go +++ b/wire/message.go @@ -7,55 +7,89 @@ package wire import ( "bytes" "fmt" - "io" - "unicode/utf8" - "github.com/pkg/errors" + "io" "github.com/kaspanet/kaspad/util/daghash" ) // MessageHeaderSize is the number of bytes in a kaspa message header. -// Kaspa network (magic) 4 bytes + command 12 bytes + payload length 4 bytes + +// Kaspa network (magic) 4 bytes + command 4 byte + payload length 4 bytes + // checksum 4 bytes. -const MessageHeaderSize = 24 - -// CommandSize is the fixed size of all commands in the common kaspa message -// header. Shorter commands must be zero padded. -const CommandSize = 12 +const MessageHeaderSize = 16 // MaxMessagePayload is the maximum bytes a message can be regardless of other // individual limits imposed by messages themselves. const MaxMessagePayload = (1024 * 1024 * 32) // 32MB +// MessageCommand is a number in the header of a message that represents its type. +type MessageCommand uint32 + +func (cmd MessageCommand) String() string { + cmdString, ok := messageCommandToString[cmd] + if !ok { + cmdString = "unknown command" + } + return fmt.Sprintf("%s [code %d]", cmdString, uint8(cmd)) +} + // Commands used in kaspa message headers which describe the type of message. const ( - CmdVersion = "version" - CmdVerAck = "verack" - CmdGetAddresses = "getaddr" - CmdAddress = "addr" - CmdGetBlockInvs = "getblockinvs" - CmdInv = "inv" - CmdGetData = "getdata" - CmdNotFound = "notfound" - CmdBlock = "block" - CmdTx = "tx" - CmdPing = "ping" - CmdPong = "pong" - CmdFilterAdd = "filteradd" - CmdFilterClear = "filterclear" - CmdFilterLoad = "filterload" - CmdMerkleBlock = "merkleblock" - CmdReject = "reject" - CmdFeeFilter = "feefilter" - CmdGetBlockLocator = "getlocator" - CmdBlockLocator = "locator" - CmdSelectedTip = "selectedtip" - CmdGetSelectedTip = "getseltip" - CmdInvRelayBlock = "invrelblk" - CmdGetRelayBlocks = "getrelblks" + CmdVersion MessageCommand = 0 + CmdVerAck MessageCommand = 1 + CmdGetAddresses MessageCommand = 2 + CmdAddress MessageCommand = 3 + CmdGetBlockInvs MessageCommand = 4 + CmdInv MessageCommand = 5 + CmdGetData MessageCommand = 6 + CmdNotFound MessageCommand = 7 + CmdBlock MessageCommand = 8 + CmdTx MessageCommand = 9 + CmdPing MessageCommand = 10 + CmdPong MessageCommand = 11 + CmdFilterAdd MessageCommand = 12 + CmdFilterClear MessageCommand = 13 + CmdFilterLoad MessageCommand = 14 + CmdMerkleBlock MessageCommand = 15 + CmdReject MessageCommand = 16 + CmdFeeFilter MessageCommand = 17 + CmdGetBlockLocator MessageCommand = 18 + CmdBlockLocator MessageCommand = 19 + CmdSelectedTip MessageCommand = 20 + CmdGetSelectedTip MessageCommand = 21 + CmdInvRelayBlock MessageCommand = 22 + CmdGetRelayBlocks MessageCommand = 23 + CmdRejectMalformed MessageCommand = 24 // Used only for reject message ) +var messageCommandToString = map[MessageCommand]string{ + CmdVersion: "Version", + CmdVerAck: "VerAck", + CmdGetAddresses: "GetAddr", + CmdAddress: "Addr", + CmdGetBlockInvs: "GetBlockInvs", + CmdInv: "Inv", + CmdGetData: "GetData", + CmdNotFound: "NotFound", + CmdBlock: "Block", + CmdTx: "Tx", + CmdPing: "Ping", + CmdPong: "Pong", + CmdFilterAdd: "FilterAdd", + CmdFilterClear: "FilterClear", + CmdFilterLoad: "FilterLoad", + CmdMerkleBlock: "MerkleBlock", + CmdReject: "Reject", + CmdFeeFilter: "FeeFilter", + CmdGetBlockLocator: "GetBlockLocator", + CmdBlockLocator: "BlockLocator", + CmdSelectedTip: "SelectedTip", + CmdGetSelectedTip: "GetSelectedTip", + CmdInvRelayBlock: "InvRelayBlock", + CmdGetRelayBlocks: "GetRelayBlocks", + CmdRejectMalformed: "RejectMalformed", +} + // Message is an interface that describes a kaspa message. A type that // implements Message has complete control over the representation of its data // and may therefore contain additional or fewer fields than those which @@ -63,13 +97,13 @@ const ( type Message interface { KaspaDecode(io.Reader, uint32) error KaspaEncode(io.Writer, uint32) error - Command() string + Command() MessageCommand MaxPayloadLength(uint32) uint32 } // MakeEmptyMessage creates a message of the appropriate concrete type based // on the command. -func MakeEmptyMessage(command string) (Message, error) { +func MakeEmptyMessage(command MessageCommand) (Message, error) { var msg Message switch command { case CmdVersion: @@ -146,10 +180,10 @@ func MakeEmptyMessage(command string) (Message, error) { // messageHeader defines the header structure for all kaspa protocol messages. type messageHeader struct { - magic KaspaNet // 4 bytes - command string // 12 bytes - length uint32 // 4 bytes - checksum [4]byte // 4 bytes + magic KaspaNet // 4 bytes + command MessageCommand // 4 bytes + length uint32 // 4 bytes + checksum [4]byte // 4 bytes } // readMessageHeader reads a kaspa message header from r. @@ -167,15 +201,11 @@ func readMessageHeader(r io.Reader) (int, *messageHeader, error) { // Create and populate a messageHeader struct from the raw header bytes. hdr := messageHeader{} - var command [CommandSize]byte - err = readElements(hr, &hdr.magic, &command, &hdr.length, &hdr.checksum) + err = readElements(hr, &hdr.magic, &hdr.command, &hdr.length, &hdr.checksum) if err != nil { return 0, nil, err } - // Strip trailing zeros from command string. - hdr.command = string(bytes.TrimRight(command[:], string(0))) - return n, &hdr, nil } @@ -205,16 +235,6 @@ func discardInput(r io.Reader, n uint32) { func WriteMessageN(w io.Writer, msg Message, pver uint32, kaspaNet KaspaNet) (int, error) { totalBytes := 0 - // Enforce max command size. - var command [CommandSize]byte - cmd := msg.Command() - if len(cmd) > CommandSize { - str := fmt.Sprintf("command [%s] is too long [max %d]", - cmd, CommandSize) - return totalBytes, messageError("WriteMessage", str) - } - copy(command[:], []byte(cmd)) - // Encode the message payload. var bw bytes.Buffer err := msg.KaspaEncode(&bw, pver) @@ -237,14 +257,14 @@ func WriteMessageN(w io.Writer, msg Message, pver uint32, kaspaNet KaspaNet) (in if uint32(lenp) > mpl { str := fmt.Sprintf("message payload is too large - encoded "+ "%d bytes, but maximum message payload size for "+ - "messages of type [%s] is %d.", lenp, cmd, mpl) + "messages of type [%s] is %d.", lenp, msg.Command(), mpl) return totalBytes, messageError("WriteMessage", str) } // Create header for the message. hdr := messageHeader{} hdr.magic = kaspaNet - hdr.command = cmd + hdr.command = msg.Command() hdr.length = uint32(lenp) copy(hdr.checksum[:], daghash.DoubleHashB(payload)[0:4]) @@ -252,7 +272,7 @@ func WriteMessageN(w io.Writer, msg Message, pver uint32, kaspaNet KaspaNet) (in // rather than directly to the writer since writeElements doesn't // return the number of bytes written. hw := bytes.NewBuffer(make([]byte, 0, MessageHeaderSize)) - err = writeElements(hw, hdr.magic, command, hdr.length, hdr.checksum) + err = writeElements(hw, hdr.magic, hdr.command, hdr.length, hdr.checksum) if err != nil { return 0, err } @@ -309,14 +329,7 @@ func ReadMessageN(r io.Reader, pver uint32, kaspaNet KaspaNet) (int, Message, [] return totalBytes, nil, nil, messageError("ReadMessage", str) } - // Check for malformed commands. command := hdr.command - if !utf8.ValidString(command) { - discardInput(r, hdr.length) - str := fmt.Sprintf("invalid command %d", []byte(command)) - return totalBytes, nil, nil, messageError("ReadMessage", str) - } - // Create struct of appropriate message type based on the command. msg, err := MakeEmptyMessage(command) if err != nil { diff --git a/wire/message_test.go b/wire/message_test.go index 8ecc78d29..639314c41 100644 --- a/wire/message_test.go +++ b/wire/message_test.go @@ -11,6 +11,7 @@ import ( "github.com/kaspanet/kaspad/util/mstime" "github.com/pkg/errors" "io" + "math" "net" "reflect" "testing" @@ -21,17 +22,17 @@ import ( // makeHeader is a convenience function to make a message header in the form of // a byte slice. It is used to force errors when reading messages. -func makeHeader(kaspaNet KaspaNet, command string, +func makeHeader(kaspaNet KaspaNet, command MessageCommand, payloadLen uint32, checksum uint32) []byte { - // The length of a kaspa message header is 24 bytes. - // 4 byte magic number of the kaspa network + 12 byte command + 4 byte + // The length of a kaspa message header is 13 bytes. + // 4 byte magic number of the kaspa network + 4 bytes command + 4 byte // payload length + 4 byte checksum. - buf := make([]byte, 24) + buf := make([]byte, 16) binary.LittleEndian.PutUint32(buf, uint32(kaspaNet)) - copy(buf[4:], []byte(command)) - binary.LittleEndian.PutUint32(buf[16:], payloadLen) - binary.LittleEndian.PutUint32(buf[20:], checksum) + binary.LittleEndian.PutUint32(buf[4:], uint32(command)) + binary.LittleEndian.PutUint32(buf[8:], payloadLen) + binary.LittleEndian.PutUint32(buf[12:], checksum) return buf } @@ -72,7 +73,7 @@ func TestMessage(t *testing.T) { msgFilterLoad := NewMsgFilterLoad([]byte{0x01}, 10, 0, BloomUpdateNone) bh := NewBlockHeader(1, []*daghash.Hash{mainnetGenesisHash, simnetGenesisHash}, &daghash.Hash{}, &daghash.Hash{}, &daghash.Hash{}, 0, 0) msgMerkleBlock := NewMsgMerkleBlock(bh) - msgReject := NewMsgReject("block", RejectDuplicate, "duplicate block") + msgReject := NewMsgReject(CmdBlock, RejectDuplicate, "duplicate block") tests := []struct { in Message // Value to encode @@ -81,26 +82,26 @@ func TestMessage(t *testing.T) { kaspaNet KaspaNet // Network to use for wire encoding bytes int // Expected num bytes read/written }{ - {msgVersion, msgVersion, pver, Mainnet, 136}, - {msgVerack, msgVerack, pver, Mainnet, 24}, - {msgGetAddresses, msgGetAddresses, pver, Mainnet, 26}, - {msgAddresses, msgAddresses, pver, Mainnet, 27}, - {msgGetBlockInvs, msgGetBlockInvs, pver, Mainnet, 88}, - {msgBlock, msgBlock, pver, Mainnet, 372}, - {msgInv, msgInv, pver, Mainnet, 25}, - {msgGetData, msgGetData, pver, Mainnet, 25}, - {msgNotFound, msgNotFound, pver, Mainnet, 25}, - {msgTx, msgTx, pver, Mainnet, 58}, - {msgPing, msgPing, pver, Mainnet, 32}, - {msgPong, msgPong, pver, Mainnet, 32}, - {msgGetBlockLocator, msgGetBlockLocator, pver, Mainnet, 88}, - {msgBlockLocator, msgBlockLocator, pver, Mainnet, 25}, - {msgFeeFilter, msgFeeFilter, pver, Mainnet, 32}, - {msgFilterAdd, msgFilterAdd, pver, Mainnet, 26}, - {msgFilterClear, msgFilterClear, pver, Mainnet, 24}, - {msgFilterLoad, msgFilterLoad, pver, Mainnet, 35}, - {msgMerkleBlock, msgMerkleBlock, pver, Mainnet, 215}, - {msgReject, msgReject, pver, Mainnet, 79}, + {msgVersion, msgVersion, pver, Mainnet, 128}, + {msgVerack, msgVerack, pver, Mainnet, 16}, + {msgGetAddresses, msgGetAddresses, pver, Mainnet, 18}, + {msgAddresses, msgAddresses, pver, Mainnet, 19}, + {msgGetBlockInvs, msgGetBlockInvs, pver, Mainnet, 80}, + {msgBlock, msgBlock, pver, Mainnet, 364}, + {msgInv, msgInv, pver, Mainnet, 17}, + {msgGetData, msgGetData, pver, Mainnet, 17}, + {msgNotFound, msgNotFound, pver, Mainnet, 17}, + {msgTx, msgTx, pver, Mainnet, 50}, + {msgPing, msgPing, pver, Mainnet, 24}, + {msgPong, msgPong, pver, Mainnet, 24}, + {msgGetBlockLocator, msgGetBlockLocator, pver, Mainnet, 80}, + {msgBlockLocator, msgBlockLocator, pver, Mainnet, 17}, + {msgFeeFilter, msgFeeFilter, pver, Mainnet, 24}, + {msgFilterAdd, msgFilterAdd, pver, Mainnet, 18}, + {msgFilterClear, msgFilterClear, pver, Mainnet, 16}, + {msgFilterLoad, msgFilterLoad, pver, Mainnet, 27}, + {msgMerkleBlock, msgMerkleBlock, pver, Mainnet, 207}, + {msgReject, msgReject, pver, Mainnet, 69}, } t.Logf("Running %d tests", len(tests)) @@ -190,31 +191,33 @@ func TestReadMessageWireErrors(t *testing.T) { testErr.Error(), wantErr) } + bogusCommand := MessageCommand(math.MaxUint8) + // Wire encoded bytes for main and testnet networks magic identifiers. - testnetBytes := makeHeader(Testnet, "", 0, 0) + testnetBytes := makeHeader(Testnet, bogusCommand, 0, 0) // Wire encoded bytes for a message that exceeds max overall message // length. mpl := uint32(MaxMessagePayload) - exceedMaxPayloadBytes := makeHeader(kaspaNet, "getaddr", mpl+1, 0) + exceedMaxPayloadBytes := makeHeader(kaspaNet, CmdAddress, mpl+1, 0) // Wire encoded bytes for a command which is invalid utf-8. - badCommandBytes := makeHeader(kaspaNet, "bogus", 0, 0) + badCommandBytes := makeHeader(kaspaNet, bogusCommand, 0, 0) badCommandBytes[4] = 0x81 // Wire encoded bytes for a command which is valid, but not supported. - unsupportedCommandBytes := makeHeader(kaspaNet, "bogus", 0, 0) + unsupportedCommandBytes := makeHeader(kaspaNet, bogusCommand, 0, 0) // Wire encoded bytes for a message which exceeds the max payload for // a specific message type. - exceedTypePayloadBytes := makeHeader(kaspaNet, "getaddr", 23, 0) + exceedTypePayloadBytes := makeHeader(kaspaNet, CmdGetAddresses, 23, 0) // Wire encoded bytes for a message which does not deliver the full // payload according to the header length. - shortPayloadBytes := makeHeader(kaspaNet, "version", 115, 0) + shortPayloadBytes := makeHeader(kaspaNet, CmdVersion, 115, 0) // Wire encoded bytes for a message with a bad checksum. - badChecksumBytes := makeHeader(kaspaNet, "version", 2, 0xbeef) + badChecksumBytes := makeHeader(kaspaNet, CmdVersion, 2, 0xbeef) badChecksumBytes = append(badChecksumBytes, []byte{0x0, 0x0}...) // Wire encoded bytes for a message which has a valid header, but is @@ -222,12 +225,12 @@ func TestReadMessageWireErrors(t *testing.T) { // contained in the message. Claim there is two, but don't provide // them. At the same time, forge the header fields so the message is // otherwise accurate. - badMessageBytes := makeHeader(kaspaNet, "addr", 1, 0xeaadc31c) + badMessageBytes := makeHeader(kaspaNet, CmdAddress, 1, 0xeaadc31c) badMessageBytes = append(badMessageBytes, 0x2) // Wire encoded bytes for a message which the header claims has 15k // bytes of data to discard. - discardBytes := makeHeader(kaspaNet, "bogus", 15*1024, 0) + discardBytes := makeHeader(kaspaNet, bogusCommand, 15*1024, 0) tests := []struct { buf []byte // Wire encoding @@ -256,7 +259,7 @@ func TestReadMessageWireErrors(t *testing.T) { kaspaNet, len(testnetBytes), &MessageError{}, - 24, + 16, }, // Exceed max overall message payload length. @@ -266,7 +269,7 @@ func TestReadMessageWireErrors(t *testing.T) { kaspaNet, len(exceedMaxPayloadBytes), &MessageError{}, - 24, + 16, }, // Invalid UTF-8 command. @@ -276,7 +279,7 @@ func TestReadMessageWireErrors(t *testing.T) { kaspaNet, len(badCommandBytes), &MessageError{}, - 24, + 16, }, // Valid, but unsupported command. @@ -286,7 +289,7 @@ func TestReadMessageWireErrors(t *testing.T) { kaspaNet, len(unsupportedCommandBytes), &MessageError{}, - 24, + 16, }, // Exceed max allowed payload for a message of a specific type. @@ -296,7 +299,7 @@ func TestReadMessageWireErrors(t *testing.T) { kaspaNet, len(exceedTypePayloadBytes), &MessageError{}, - 24, + 16, }, // Message with a payload shorter than the header indicates. @@ -306,7 +309,7 @@ func TestReadMessageWireErrors(t *testing.T) { kaspaNet, len(shortPayloadBytes), io.EOF, - 24, + 16, }, // Message with a bad checksum. @@ -316,7 +319,7 @@ func TestReadMessageWireErrors(t *testing.T) { kaspaNet, len(badChecksumBytes), &MessageError{}, - 26, + 18, }, // Message with a valid header, but wrong format. @@ -326,7 +329,7 @@ func TestReadMessageWireErrors(t *testing.T) { kaspaNet, len(badMessageBytes), io.EOF, - 25, + 17, }, // 15k bytes of data to discard. @@ -336,7 +339,7 @@ func TestReadMessageWireErrors(t *testing.T) { kaspaNet, len(discardBytes), &MessageError{}, - 24, + 16, }, } @@ -377,9 +380,6 @@ func TestWriteMessageWireErrors(t *testing.T) { kaspaNet := Mainnet wireErr := &MessageError{} - // Fake message with a command that is too long. - badCommandMsg := &fakeMessage{command: "somethingtoolong"} - // Fake message with a problem during encoding encodeErrMsg := &fakeMessage{forceEncodeErr: true} @@ -394,7 +394,7 @@ func TestWriteMessageWireErrors(t *testing.T) { // Fake message that is used to force errors in the header and payload // writes. bogusPayload := []byte{0x01, 0x02, 0x03, 0x04} - bogusMsg := &fakeMessage{command: "bogus", payload: bogusPayload} + bogusMsg := &fakeMessage{command: MessageCommand(math.MaxUint8), payload: bogusPayload} tests := []struct { msg Message // Message to encode @@ -404,8 +404,6 @@ func TestWriteMessageWireErrors(t *testing.T) { err error // Expected error bytes int // Expected num bytes written }{ - // Command too long. - {badCommandMsg, pver, kaspaNet, 0, wireErr, 0}, // Force error in payload encode. {encodeErrMsg, pver, kaspaNet, 0, wireErr, 0}, // Force error due to exceeding max overall message payload size. @@ -415,7 +413,7 @@ func TestWriteMessageWireErrors(t *testing.T) { // Force error in header write. {bogusMsg, pver, kaspaNet, 0, io.ErrShortWrite, 0}, // Force error in payload write. - {bogusMsg, pver, kaspaNet, 24, io.ErrShortWrite, 24}, + {bogusMsg, pver, kaspaNet, 16, io.ErrShortWrite, 16}, } t.Logf("Running %d tests", len(tests)) diff --git a/wire/msgaddresses.go b/wire/msgaddresses.go index c6f97d5d2..7541fb582 100644 --- a/wire/msgaddresses.go +++ b/wire/msgaddresses.go @@ -158,7 +158,7 @@ func (msg *MsgAddresses) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgAddresses) Command() string { +func (msg *MsgAddresses) Command() MessageCommand { return CmdAddress } diff --git a/wire/msgaddresses_test.go b/wire/msgaddresses_test.go index f21277d5f..15c2a9624 100644 --- a/wire/msgaddresses_test.go +++ b/wire/msgaddresses_test.go @@ -22,7 +22,7 @@ func TestAddresses(t *testing.T) { pver := ProtocolVersion // Ensure the command is expected value. - wantCmd := "addr" + wantCmd := MessageCommand(3) msg := NewMsgAddresses(false, nil) if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgAddresses: wrong command - got %v want %v", diff --git a/wire/msgblock.go b/wire/msgblock.go index 592b758e4..71e84c73b 100644 --- a/wire/msgblock.go +++ b/wire/msgblock.go @@ -213,7 +213,7 @@ func (msg *MsgBlock) SerializeSize() int { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgBlock) Command() string { +func (msg *MsgBlock) Command() MessageCommand { return CmdBlock } diff --git a/wire/msgblock_test.go b/wire/msgblock_test.go index 8aaef962d..e0d0b57fa 100644 --- a/wire/msgblock_test.go +++ b/wire/msgblock_test.go @@ -32,7 +32,7 @@ func TestBlock(t *testing.T) { bh := NewBlockHeader(1, parentHashes, hashMerkleRoot, acceptedIDMerkleRoot, utxoCommitment, bits, nonce) // Ensure the command is expected value. - wantCmd := "block" + wantCmd := MessageCommand(8) msg := NewMsgBlock(bh) if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgBlock: wrong command - got %v want %v", diff --git a/wire/msgblocklocator.go b/wire/msgblocklocator.go index bcf761e76..0ac85a7d7 100644 --- a/wire/msgblocklocator.go +++ b/wire/msgblocklocator.go @@ -89,7 +89,7 @@ func (msg *MsgBlockLocator) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgBlockLocator) Command() string { +func (msg *MsgBlockLocator) Command() MessageCommand { return CmdBlockLocator } diff --git a/wire/msgblocklocator_test.go b/wire/msgblocklocator_test.go index 5f50ff35a..b89fec1f7 100644 --- a/wire/msgblocklocator_test.go +++ b/wire/msgblocklocator_test.go @@ -24,7 +24,7 @@ func TestBlockLocator(t *testing.T) { msg := NewMsgBlockLocator() // Ensure the command is expected value. - wantCmd := "locator" + wantCmd := MessageCommand(19) if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgBlockLocator: wrong command - got %v want %v", cmd, wantCmd) diff --git a/wire/msgfeefilter.go b/wire/msgfeefilter.go index ff56e0836..1585db621 100644 --- a/wire/msgfeefilter.go +++ b/wire/msgfeefilter.go @@ -32,7 +32,7 @@ func (msg *MsgFeeFilter) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgFeeFilter) Command() string { +func (msg *MsgFeeFilter) Command() MessageCommand { return CmdFeeFilter } diff --git a/wire/msgfeefilter_test.go b/wire/msgfeefilter_test.go index ff652bc2f..6d419ba85 100644 --- a/wire/msgfeefilter_test.go +++ b/wire/msgfeefilter_test.go @@ -27,7 +27,7 @@ func TestFeeFilterLatest(t *testing.T) { } // Ensure the command is expected value. - wantCmd := "feefilter" + wantCmd := MessageCommand(17) if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgFeeFilter: wrong command - got %v want %v", cmd, wantCmd) diff --git a/wire/msgfilteradd.go b/wire/msgfilteradd.go index 9bc061008..1eaa189d5 100644 --- a/wire/msgfilteradd.go +++ b/wire/msgfilteradd.go @@ -49,7 +49,7 @@ func (msg *MsgFilterAdd) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgFilterAdd) Command() string { +func (msg *MsgFilterAdd) Command() MessageCommand { return CmdFilterAdd } diff --git a/wire/msgfilteradd_test.go b/wire/msgfilteradd_test.go index a0af6fd2c..21395f325 100644 --- a/wire/msgfilteradd_test.go +++ b/wire/msgfilteradd_test.go @@ -21,7 +21,7 @@ func TestFilterAddLatest(t *testing.T) { msg := NewMsgFilterAdd(data) // Ensure the command is expected value. - wantCmd := "filteradd" + wantCmd := MessageCommand(12) if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgFilterAdd: wrong command - got %v want %v", cmd, wantCmd) diff --git a/wire/msgfilterclear.go b/wire/msgfilterclear.go index 6fe9e40b8..a22ec1010 100644 --- a/wire/msgfilterclear.go +++ b/wire/msgfilterclear.go @@ -29,7 +29,7 @@ func (msg *MsgFilterClear) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgFilterClear) Command() string { +func (msg *MsgFilterClear) Command() MessageCommand { return CmdFilterClear } diff --git a/wire/msgfilterclear_test.go b/wire/msgfilterclear_test.go index a8f21a554..995ec152f 100644 --- a/wire/msgfilterclear_test.go +++ b/wire/msgfilterclear_test.go @@ -20,7 +20,7 @@ func TestFilterClearLatest(t *testing.T) { msg := NewMsgFilterClear() // Ensure the command is expected value. - wantCmd := "filterclear" + wantCmd := MessageCommand(13) if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgFilterClear: wrong command - got %v want %v", cmd, wantCmd) diff --git a/wire/msgfilterload.go b/wire/msgfilterload.go index 9382296e8..f04b8bb7e 100644 --- a/wire/msgfilterload.go +++ b/wire/msgfilterload.go @@ -93,7 +93,7 @@ func (msg *MsgFilterLoad) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgFilterLoad) Command() string { +func (msg *MsgFilterLoad) Command() MessageCommand { return CmdFilterLoad } diff --git a/wire/msgfilterload_test.go b/wire/msgfilterload_test.go index 91b810e37..cbce450a6 100644 --- a/wire/msgfilterload_test.go +++ b/wire/msgfilterload_test.go @@ -21,7 +21,7 @@ func TestFilterLoadLatest(t *testing.T) { msg := NewMsgFilterLoad(data, 10, 0, 0) // Ensure the command is expected value. - wantCmd := "filterload" + wantCmd := MessageCommand(14) if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgFilterLoad: wrong command - got %v want %v", cmd, wantCmd) diff --git a/wire/msggetaddresses.go b/wire/msggetaddresses.go index a71533fd7..dbceec6f0 100644 --- a/wire/msggetaddresses.go +++ b/wire/msggetaddresses.go @@ -83,7 +83,7 @@ func (msg *MsgGetAddresses) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgGetAddresses) Command() string { +func (msg *MsgGetAddresses) Command() MessageCommand { return CmdGetAddresses } diff --git a/wire/msggetaddresses_test.go b/wire/msggetaddresses_test.go index 605b15314..23846d282 100644 --- a/wire/msggetaddresses_test.go +++ b/wire/msggetaddresses_test.go @@ -18,7 +18,7 @@ func TestGetAddresses(t *testing.T) { pver := ProtocolVersion // Ensure the command is expected value. - wantCmd := "getaddr" + wantCmd := MessageCommand(2) msg := NewMsgGetAddresses(false, nil) if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgGetAddresses: wrong command - got %v want %v", diff --git a/wire/msggetblockinvs.go b/wire/msggetblockinvs.go index 5e6856504..7dc2607d0 100644 --- a/wire/msggetblockinvs.go +++ b/wire/msggetblockinvs.go @@ -44,7 +44,7 @@ func (msg *MsgGetBlockInvs) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgGetBlockInvs) Command() string { +func (msg *MsgGetBlockInvs) Command() MessageCommand { return CmdGetBlockInvs } diff --git a/wire/msggetblockinvs_test.go b/wire/msggetblockinvs_test.go index 43ea49ad0..64b0f3ed3 100644 --- a/wire/msggetblockinvs_test.go +++ b/wire/msggetblockinvs_test.go @@ -39,7 +39,7 @@ func TestGetBlockInvs(t *testing.T) { } // Ensure the command is expected value. - wantCmd := "getblockinvs" + wantCmd := MessageCommand(4) if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgGetBlockInvs: wrong command - got %v want %v", cmd, wantCmd) diff --git a/wire/msggetblocklocator.go b/wire/msggetblocklocator.go index 5ef3f48c5..19e7e9640 100644 --- a/wire/msggetblocklocator.go +++ b/wire/msggetblocklocator.go @@ -48,7 +48,7 @@ func (msg *MsgGetBlockLocator) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgGetBlockLocator) Command() string { +func (msg *MsgGetBlockLocator) Command() MessageCommand { return CmdGetBlockLocator } diff --git a/wire/msggetblocklocator_test.go b/wire/msggetblocklocator_test.go index 2605341ab..dd4c13384 100644 --- a/wire/msggetblocklocator_test.go +++ b/wire/msggetblocklocator_test.go @@ -22,7 +22,7 @@ func TestGetBlockLocator(t *testing.T) { } // Ensure the command is expected value. - wantCmd := "getlocator" + wantCmd := MessageCommand(18) msg := NewMsgGetBlockLocator(highHash, &daghash.ZeroHash) if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgGetBlockLocator: wrong command - got %v want %v", diff --git a/wire/msggetdata.go b/wire/msggetdata.go index 480152daa..ab43d2341 100644 --- a/wire/msggetdata.go +++ b/wire/msggetdata.go @@ -92,7 +92,7 @@ func (msg *MsgGetData) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgGetData) Command() string { +func (msg *MsgGetData) Command() MessageCommand { return CmdGetData } diff --git a/wire/msggetdata_test.go b/wire/msggetdata_test.go index 842bf3fef..9da50ff99 100644 --- a/wire/msggetdata_test.go +++ b/wire/msggetdata_test.go @@ -20,7 +20,7 @@ func TestGetData(t *testing.T) { pver := ProtocolVersion // Ensure the command is expected value. - wantCmd := "getdata" + wantCmd := MessageCommand(6) msg := NewMsgGetData() if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgGetData: wrong command - got %v want %v", diff --git a/wire/msggetrelayblocks.go b/wire/msggetrelayblocks.go index 5c979dcb1..4f07aab41 100644 --- a/wire/msggetrelayblocks.go +++ b/wire/msggetrelayblocks.go @@ -30,7 +30,7 @@ func (msg *MsgGetRelayBlocks) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgGetRelayBlocks) Command() string { +func (msg *MsgGetRelayBlocks) Command() MessageCommand { return CmdGetRelayBlocks } diff --git a/wire/msggetselectedtip.go b/wire/msggetselectedtip.go index ebe718d89..50016ebd3 100644 --- a/wire/msggetselectedtip.go +++ b/wire/msggetselectedtip.go @@ -24,7 +24,7 @@ func (msg *MsgGetSelectedTip) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgGetSelectedTip) Command() string { +func (msg *MsgGetSelectedTip) Command() MessageCommand { return CmdGetSelectedTip } diff --git a/wire/msggetselectedtip_test.go b/wire/msggetselectedtip_test.go index df96fe1b9..bc8e7309d 100644 --- a/wire/msggetselectedtip_test.go +++ b/wire/msggetselectedtip_test.go @@ -17,7 +17,7 @@ func TestGetSelectedTip(t *testing.T) { pver := ProtocolVersion // Ensure the command is expected value. - wantCmd := "getseltip" + wantCmd := MessageCommand(21) msg := NewMsgGetSelectedTip() if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgGetSelectedTip: wrong command - got %v want %v", diff --git a/wire/msginv.go b/wire/msginv.go index 2b18abf45..32d07e78e 100644 --- a/wire/msginv.go +++ b/wire/msginv.go @@ -100,7 +100,7 @@ func (msg *MsgInv) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgInv) Command() string { +func (msg *MsgInv) Command() MessageCommand { return CmdInv } diff --git a/wire/msginv_test.go b/wire/msginv_test.go index cebb93015..69666ca71 100644 --- a/wire/msginv_test.go +++ b/wire/msginv_test.go @@ -20,7 +20,7 @@ func TestInv(t *testing.T) { pver := ProtocolVersion // Ensure the command is expected value. - wantCmd := "inv" + wantCmd := MessageCommand(5) msg := NewMsgInv() if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgInv: wrong command - got %v want %v", diff --git a/wire/msginvrelayblock.go b/wire/msginvrelayblock.go index 846bf203f..24732f2d1 100644 --- a/wire/msginvrelayblock.go +++ b/wire/msginvrelayblock.go @@ -26,7 +26,7 @@ func (msg *MsgInvRelayBlock) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgInvRelayBlock) Command() string { +func (msg *MsgInvRelayBlock) Command() MessageCommand { return CmdInvRelayBlock } diff --git a/wire/msgmerkleblock.go b/wire/msgmerkleblock.go index 303cf961d..33f21b25a 100644 --- a/wire/msgmerkleblock.go +++ b/wire/msgmerkleblock.go @@ -126,7 +126,7 @@ func (msg *MsgMerkleBlock) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgMerkleBlock) Command() string { +func (msg *MsgMerkleBlock) Command() MessageCommand { return CmdMerkleBlock } diff --git a/wire/msgmerkleblock_test.go b/wire/msgmerkleblock_test.go index fb7980309..1aadb2923 100644 --- a/wire/msgmerkleblock_test.go +++ b/wire/msgmerkleblock_test.go @@ -30,7 +30,7 @@ func TestMerkleBlock(t *testing.T) { bh := NewBlockHeader(1, parentHashes, hashMerkleRoot, acceptedIDMerkleRoot, utxoCommitment, bits, nonce) // Ensure the command is expected value. - wantCmd := "merkleblock" + wantCmd := MessageCommand(15) msg := NewMsgMerkleBlock(bh) if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgBlock: wrong command - got %v want %v", diff --git a/wire/msgnotfound.go b/wire/msgnotfound.go index 4e28aa79c..7edd84afa 100644 --- a/wire/msgnotfound.go +++ b/wire/msgnotfound.go @@ -89,7 +89,7 @@ func (msg *MsgNotFound) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgNotFound) Command() string { +func (msg *MsgNotFound) Command() MessageCommand { return CmdNotFound } diff --git a/wire/msgnotfound_test.go b/wire/msgnotfound_test.go index 0d7810f5b..332f6124a 100644 --- a/wire/msgnotfound_test.go +++ b/wire/msgnotfound_test.go @@ -20,7 +20,7 @@ func TestNotFound(t *testing.T) { pver := ProtocolVersion // Ensure the command is expected value. - wantCmd := "notfound" + wantCmd := MessageCommand(7) msg := NewMsgNotFound() if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgNotFound: wrong command - got %v want %v", diff --git a/wire/msgping.go b/wire/msgping.go index c12350e0a..8bb411e59 100644 --- a/wire/msgping.go +++ b/wire/msgping.go @@ -49,7 +49,7 @@ func (msg *MsgPing) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgPing) Command() string { +func (msg *MsgPing) Command() MessageCommand { return CmdPing } diff --git a/wire/msgping_test.go b/wire/msgping_test.go index 86ffb7d41..2e5ab89c5 100644 --- a/wire/msgping_test.go +++ b/wire/msgping_test.go @@ -31,7 +31,7 @@ func TestPing(t *testing.T) { } // Ensure the command is expected value. - wantCmd := "ping" + wantCmd := MessageCommand(10) if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgPing: wrong command - got %v want %v", cmd, wantCmd) diff --git a/wire/msgpong.go b/wire/msgpong.go index 5bd42d41b..c0dddaa4f 100644 --- a/wire/msgpong.go +++ b/wire/msgpong.go @@ -33,7 +33,7 @@ func (msg *MsgPong) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgPong) Command() string { +func (msg *MsgPong) Command() MessageCommand { return CmdPong } diff --git a/wire/msgpong_test.go b/wire/msgpong_test.go index 06eb330be..a05d87903 100644 --- a/wire/msgpong_test.go +++ b/wire/msgpong_test.go @@ -30,7 +30,7 @@ func TestPongLatest(t *testing.T) { } // Ensure the command is expected value. - wantCmd := "pong" + wantCmd := MessageCommand(11) if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgPong: wrong command - got %v want %v", cmd, wantCmd) diff --git a/wire/msgreject.go b/wire/msgreject.go index 267b4d746..a588a245f 100644 --- a/wire/msgreject.go +++ b/wire/msgreject.go @@ -60,7 +60,7 @@ type MsgReject struct { // Cmd is the command for the message which was rejected such as // as CmdBlock or CmdTx. This can be obtained from the Command function // of a Message. - Cmd string + Cmd MessageCommand // RejectCode is a code indicating why the command was rejected. It // is encoded as a uint8 on the wire. @@ -79,11 +79,10 @@ type MsgReject struct { // This is part of the Message interface implementation. func (msg *MsgReject) KaspaDecode(r io.Reader, pver uint32) error { // Command that was rejected. - cmd, err := ReadVarString(r, pver) + err := ReadElement(r, &msg.Cmd) if err != nil { return err } - msg.Cmd = cmd // Code indicating why the command was rejected. err = ReadElement(r, &msg.Code) @@ -116,7 +115,7 @@ func (msg *MsgReject) KaspaDecode(r io.Reader, pver uint32) error { // This is part of the Message interface implementation. func (msg *MsgReject) KaspaEncode(w io.Writer, pver uint32) error { // Command that was rejected. - err := WriteVarString(w, msg.Cmd) + err := WriteElement(w, msg.Cmd) if err != nil { return err } @@ -148,7 +147,7 @@ func (msg *MsgReject) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgReject) Command() string { +func (msg *MsgReject) Command() MessageCommand { return CmdReject } @@ -163,7 +162,7 @@ func (msg *MsgReject) MaxPayloadLength(pver uint32) uint32 { // NewMsgReject returns a new kaspa reject message that conforms to the // Message interface. See MsgReject for details. -func NewMsgReject(command string, code RejectCode, reason string) *MsgReject { +func NewMsgReject(command MessageCommand, code RejectCode, reason string) *MsgReject { return &MsgReject{ Cmd: command, Code: code, diff --git a/wire/msgreject_test.go b/wire/msgreject_test.go index c13fe166b..83744e922 100644 --- a/wire/msgreject_test.go +++ b/wire/msgreject_test.go @@ -71,7 +71,7 @@ func TestRejectLatest(t *testing.T) { } // Ensure the command is expected value. - wantCmd := "reject" + wantCmd := MessageCommand(16) if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgReject: wrong command - got %v want %v", cmd, wantCmd) @@ -150,12 +150,12 @@ func TestRejectWire(t *testing.T) { // Latest protocol version rejected command version (no hash). { MsgReject{ - Cmd: "version", + Cmd: CmdVersion, Code: RejectDuplicate, Reason: "duplicate version", }, []byte{ - 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, // "version" + 0x00, 0x00, 0x00, 0x00, // CmdVersion 0x12, // RejectDuplicate 0x11, 0x64, 0x75, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x65, 0x20, 0x76, 0x65, 0x72, 0x73, 0x69, @@ -166,13 +166,13 @@ func TestRejectWire(t *testing.T) { // Latest protocol version rejected command block (has hash). { MsgReject{ - Cmd: "block", + Cmd: CmdBlock, Code: RejectDuplicate, Reason: "duplicate block", Hash: mainnetGenesisHash, }, []byte{ - 0x05, 0x62, 0x6c, 0x6f, 0x63, 0x6b, // "block" + 0x08, 0x00, 0x00, 0x00, // CmdBlock 0x12, // RejectDuplicate 0x0f, 0x64, 0x75, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x65, 0x20, 0x62, 0x6c, 0x6f, 0x63, 0x6b, // "duplicate block" @@ -221,10 +221,10 @@ func TestRejectWire(t *testing.T) { func TestRejectWireErrors(t *testing.T) { pver := ProtocolVersion - baseReject := NewMsgReject("block", RejectDuplicate, "duplicate block") + baseReject := NewMsgReject(CmdBlock, RejectDuplicate, "duplicate block") baseReject.Hash = mainnetGenesisHash baseRejectEncoded := []byte{ - 0x05, 0x62, 0x6c, 0x6f, 0x63, 0x6b, // "block" + 0x08, 0x00, 0x00, 0x00, // CmdBlock 0x12, // RejectDuplicate 0x0f, 0x64, 0x75, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x65, 0x20, 0x62, 0x6c, 0x6f, 0x63, 0x6b, // "duplicate block" @@ -246,11 +246,11 @@ func TestRejectWireErrors(t *testing.T) { // Force error in reject command. {baseReject, baseRejectEncoded, pver, 0, io.ErrShortWrite, io.EOF}, // Force error in reject code. - {baseReject, baseRejectEncoded, pver, 6, io.ErrShortWrite, io.EOF}, + {baseReject, baseRejectEncoded, pver, 4, io.ErrShortWrite, io.EOF}, // Force error in reject reason. - {baseReject, baseRejectEncoded, pver, 7, io.ErrShortWrite, io.EOF}, + {baseReject, baseRejectEncoded, pver, 5, io.ErrShortWrite, io.EOF}, // Force error in reject hash. - {baseReject, baseRejectEncoded, pver, 23, io.ErrShortWrite, io.EOF}, + {baseReject, baseRejectEncoded, pver, 21, io.ErrShortWrite, io.EOF}, } t.Logf("Running %d tests", len(tests)) diff --git a/wire/msgselectedtip.go b/wire/msgselectedtip.go index 4abe5cdc8..04684856e 100644 --- a/wire/msgselectedtip.go +++ b/wire/msgselectedtip.go @@ -33,7 +33,7 @@ func (msg *MsgSelectedTip) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgSelectedTip) Command() string { +func (msg *MsgSelectedTip) Command() MessageCommand { return CmdSelectedTip } diff --git a/wire/msgselectedtip_test.go b/wire/msgselectedtip_test.go index 7a674a829..08150aaab 100644 --- a/wire/msgselectedtip_test.go +++ b/wire/msgselectedtip_test.go @@ -14,7 +14,7 @@ func TestSelectedTip(t *testing.T) { pver := ProtocolVersion // Ensure the command is expected value. - wantCmd := "selectedtip" + wantCmd := MessageCommand(20) msg := NewMsgSelectedTip(&daghash.ZeroHash) if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgSelectedTip: wrong command - got %v want %v", diff --git a/wire/msgtx.go b/wire/msgtx.go index 783fd6f62..508321e65 100644 --- a/wire/msgtx.go +++ b/wire/msgtx.go @@ -762,7 +762,7 @@ func (msg *MsgTx) serializeSize(encodingFlags txEncoding) int { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgTx) Command() string { +func (msg *MsgTx) Command() MessageCommand { return CmdTx } diff --git a/wire/msgtx_test.go b/wire/msgtx_test.go index d315c885c..98e711990 100644 --- a/wire/msgtx_test.go +++ b/wire/msgtx_test.go @@ -30,7 +30,7 @@ func TestTx(t *testing.T) { } // Ensure the command is expected value. - wantCmd := "tx" + wantCmd := MessageCommand(9) msg := NewNativeMsgTx(1, nil, nil) if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgAddresses: wrong command - got %v want %v", diff --git a/wire/msgverack.go b/wire/msgverack.go index 34c3ffbef..15f85c8d3 100644 --- a/wire/msgverack.go +++ b/wire/msgverack.go @@ -29,7 +29,7 @@ func (msg *MsgVerAck) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgVerAck) Command() string { +func (msg *MsgVerAck) Command() MessageCommand { return CmdVerAck } diff --git a/wire/msgverack_test.go b/wire/msgverack_test.go index 6e4e84a82..f930bc246 100644 --- a/wire/msgverack_test.go +++ b/wire/msgverack_test.go @@ -17,7 +17,7 @@ func TestVerAck(t *testing.T) { pver := ProtocolVersion // Ensure the command is expected value. - wantCmd := "verack" + wantCmd := MessageCommand(1) msg := NewMsgVerAck() if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgVerAck: wrong command - got %v want %v", diff --git a/wire/msgversion.go b/wire/msgversion.go index b27f74653..fdfb3ba1d 100644 --- a/wire/msgversion.go +++ b/wire/msgversion.go @@ -221,7 +221,7 @@ func (msg *MsgVersion) KaspaEncode(w io.Writer, pver uint32) error { // Command returns the protocol command string for the message. This is part // of the Message interface implementation. -func (msg *MsgVersion) Command() string { +func (msg *MsgVersion) Command() MessageCommand { return CmdVersion } diff --git a/wire/msgversion_test.go b/wire/msgversion_test.go index 01bee0f05..d7ab2fb57 100644 --- a/wire/msgversion_test.go +++ b/wire/msgversion_test.go @@ -83,7 +83,7 @@ func TestVersion(t *testing.T) { } // Ensure the command is expected value. - wantCmd := "version" + wantCmd := MessageCommand(0) if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgVersion: wrong command - got %v want %v", cmd, wantCmd)