diff --git a/app/protocol/flows/ping/send.go b/app/protocol/flows/ping/send.go index e46c1672e..c1b0e4f46 100644 --- a/app/protocol/flows/ping/send.go +++ b/app/protocol/flows/ping/send.go @@ -1,10 +1,10 @@ package ping import ( + "github.com/kaspanet/kaspad/app/protocol/common" "time" "github.com/kaspanet/kaspad/app/appmessage" - "github.com/kaspanet/kaspad/app/protocol/common" peerpkg "github.com/kaspanet/kaspad/app/protocol/peer" "github.com/kaspanet/kaspad/app/protocol/protocolerrors" "github.com/kaspanet/kaspad/infrastructure/network/netadapter/router" diff --git a/infrastructure/network/netadapter/server/grpcserver/grpc_connection.go b/infrastructure/network/netadapter/server/grpcserver/grpc_connection.go index 38d042c8b..74f311cc9 100644 --- a/infrastructure/network/netadapter/server/grpcserver/grpc_connection.go +++ b/infrastructure/network/netadapter/server/grpcserver/grpc_connection.go @@ -13,11 +13,11 @@ import ( ) type gRPCConnection struct { - server *gRPCServer - address *net.TCPAddr - isOutbound bool - stream grpcStream - router *router.Router + server *gRPCServer + address *net.TCPAddr + stream grpcStream + router *router.Router + lowLevelClientConnection *grpc.ClientConn // streamLock protects concurrent access to stream. // Note that it's an RWMutex. Despite what the name @@ -34,14 +34,16 @@ type gRPCConnection struct { isConnected uint32 } -func newConnection(server *gRPCServer, address *net.TCPAddr, isOutbound bool, stream grpcStream) *gRPCConnection { +func newConnection(server *gRPCServer, address *net.TCPAddr, stream grpcStream, + lowLevelClientConnection *grpc.ClientConn) *gRPCConnection { + connection := &gRPCConnection{ - server: server, - address: address, - isOutbound: isOutbound, - stream: stream, - stopChan: make(chan struct{}), - isConnected: 1, + server: server, + address: address, + stream: stream, + stopChan: make(chan struct{}), + isConnected: 1, + lowLevelClientConnection: lowLevelClientConnection, } return connection @@ -83,7 +85,7 @@ func (c *gRPCConnection) SetOnInvalidMessageHandler(onInvalidMessageHandler serv } func (c *gRPCConnection) IsOutbound() bool { - return c.isOutbound + return c.lowLevelClientConnection != nil } // Disconnect disconnects the connection @@ -98,7 +100,7 @@ func (c *gRPCConnection) Disconnect() { close(c.stopChan) - if c.isOutbound { + if c.IsOutbound() { c.closeSend() log.Debugf("Disconnected from %s", c) } @@ -138,5 +140,8 @@ func (c *gRPCConnection) closeSend() { defer c.streamLock.Unlock() clientStream := c.stream.(protowire.P2P_MessageStreamClient) - _ = clientStream.CloseSend() // ignore error because we don't really know what's the status of the connection + + // ignore error because we don't really know what's the status of the connection + _ = clientStream.CloseSend() + _ = c.lowLevelClientConnection.Close() } diff --git a/infrastructure/network/netadapter/server/grpcserver/grpc_server.go b/infrastructure/network/netadapter/server/grpcserver/grpc_server.go index 7a74d5cc6..ad82eb376 100644 --- a/infrastructure/network/netadapter/server/grpcserver/grpc_server.go +++ b/infrastructure/network/netadapter/server/grpcserver/grpc_server.go @@ -90,12 +90,12 @@ func (s *gRPCServer) Connect(address string) (server.Connection, error) { ctx, cancel := context.WithTimeout(context.Background(), dialTimeout) defer cancel() - gRPCConnection, err := grpc.DialContext(ctx, address, grpc.WithInsecure(), grpc.WithBlock()) + gRPCClientConnection, err := grpc.DialContext(ctx, address, grpc.WithInsecure(), grpc.WithBlock()) if err != nil { return nil, errors.Wrapf(err, "error connecting to %s", address) } - client := protowire.NewP2PClient(gRPCConnection) + client := protowire.NewP2PClient(gRPCClientConnection) stream, err := client.MessageStream(context.Background(), grpc.UseCompressor(gzip.Name)) if err != nil { return nil, errors.Wrapf(err, "error getting client stream for %s", address) @@ -110,7 +110,7 @@ func (s *gRPCServer) Connect(address string) (server.Connection, error) { return nil, errors.Errorf("non-tcp addresses are not supported") } - connection := newConnection(s, tcpAddress, true, stream) + connection := newConnection(s, tcpAddress, stream, gRPCClientConnection) err = s.onConnectedHandler(connection) if err != nil { diff --git a/infrastructure/network/netadapter/server/grpcserver/p2pserver.go b/infrastructure/network/netadapter/server/grpcserver/p2pserver.go index d0cea621a..92c28e06e 100644 --- a/infrastructure/network/netadapter/server/grpcserver/p2pserver.go +++ b/infrastructure/network/netadapter/server/grpcserver/p2pserver.go @@ -29,7 +29,7 @@ func (p *p2pServer) MessageStream(stream protowire.P2P_MessageStreamServer) erro return errors.Errorf("non-tcp connections are not supported") } - connection := newConnection(p.server, tcpAddress, false, stream) + connection := newConnection(p.server, tcpAddress, stream, nil) err := p.server.onConnectedHandler(connection) if err != nil {