diff --git a/infrastructure/network/rpcclient/grpcclient/grpcclient.go b/infrastructure/network/rpcclient/grpcclient/grpcclient.go index e15731aa3..4042c6267 100644 --- a/infrastructure/network/rpcclient/grpcclient/grpcclient.go +++ b/infrastructure/network/rpcclient/grpcclient/grpcclient.go @@ -22,6 +22,7 @@ type OnDisconnectedHandler func() // GRPCClient is a gRPC-based RPC client type GRPCClient struct { stream protowire.RPC_MessageStreamClient + connection *grpc.ClientConn onErrorHandler OnErrorHandler onDisconnectedHandler OnDisconnectedHandler } @@ -43,7 +44,12 @@ func Connect(address string) (*GRPCClient, error) { if err != nil { return nil, errors.Wrapf(err, "error getting client stream for %s", address) } - return &GRPCClient{stream: stream}, nil + return &GRPCClient{stream: stream, connection: gRPCConnection}, nil +} + +// Close closes the underlying grpc connection +func (c *GRPCClient) Close() error { + return c.connection.Close() } // Disconnect disconnects from the RPC server diff --git a/infrastructure/network/rpcclient/rpcclient.go b/infrastructure/network/rpcclient/rpcclient.go index 7256f6c82..e4671c028 100644 --- a/infrastructure/network/rpcclient/rpcclient.go +++ b/infrastructure/network/rpcclient/rpcclient.go @@ -143,6 +143,9 @@ func (c *RPCClient) handleClientDisconnected() { } func (c *RPCClient) handleClientError(err error) { + if atomic.LoadUint32(&c.isClosed) == 1 { + return + } log.Warnf("Received error from client: %s", err) c.handleClientDisconnected() } @@ -159,7 +162,7 @@ func (c *RPCClient) Close() error { return errors.Errorf("Cannot close a client that had already been closed") } c.rpcRouter.router.Close() - return nil + return c.GRPCClient.Close() } // Address returns the address the RPC client connected to diff --git a/testing/integration/rpc_test.go b/testing/integration/rpc_test.go index 44fb1e40d..76caf1ff5 100644 --- a/testing/integration/rpc_test.go +++ b/testing/integration/rpc_test.go @@ -2,6 +2,7 @@ package integration import ( "github.com/kaspanet/kaspad/infrastructure/config" + "runtime" "testing" "time" @@ -26,6 +27,37 @@ func newTestRPCClient(rpcAddress string) (*testRPCClient, error) { }, nil } +func connectAndClose(rpcAddress string) error { + client, err := rpcclient.NewRPCClient(rpcAddress) + if err != nil { + return err + } + defer client.Close() + return nil +} + +func TestRPCClientGoroutineLeak(t *testing.T) { + _, teardown := setupHarness(t, &harnessParams{ + p2pAddress: p2pAddress1, + rpcAddress: rpcAddress1, + miningAddress: miningAddress1, + miningAddressPrivateKey: miningAddress1PrivateKey, + }) + defer teardown() + numGoroutinesBefore := runtime.NumGoroutine() + for i := 1; i < 100; i++ { + err := connectAndClose(rpcAddress1) + if err != nil { + t.Fatalf("Failed to set up an RPC client: %s", err) + } + time.Sleep(10 * time.Millisecond) + if runtime.NumGoroutine() > numGoroutinesBefore+10 { + t.Fatalf("Number of goroutines is increasing for each RPC client open (%d -> %d), which indicates a memory leak", + numGoroutinesBefore, runtime.NumGoroutine()) + } + } +} + func TestRPCMaxInboundConnections(t *testing.T) { harness, teardown := setupHarness(t, &harnessParams{ p2pAddress: p2pAddress1,