diff --git a/infrastructure/network/rpcclient/grpcclient/grpcclient.go b/infrastructure/network/rpcclient/grpcclient/grpcclient.go index e15731aa3..4042c6267 100644 --- a/infrastructure/network/rpcclient/grpcclient/grpcclient.go +++ b/infrastructure/network/rpcclient/grpcclient/grpcclient.go @@ -22,6 +22,7 @@ type OnDisconnectedHandler func() // GRPCClient is a gRPC-based RPC client type GRPCClient struct { stream protowire.RPC_MessageStreamClient + connection *grpc.ClientConn onErrorHandler OnErrorHandler onDisconnectedHandler OnDisconnectedHandler } @@ -43,7 +44,12 @@ func Connect(address string) (*GRPCClient, error) { if err != nil { return nil, errors.Wrapf(err, "error getting client stream for %s", address) } - return &GRPCClient{stream: stream}, nil + return &GRPCClient{stream: stream, connection: gRPCConnection}, nil +} + +// Close closes the underlying grpc connection +func (c *GRPCClient) Close() error { + return c.connection.Close() } // Disconnect disconnects from the RPC server diff --git a/infrastructure/network/rpcclient/rpcclient.go b/infrastructure/network/rpcclient/rpcclient.go index 7256f6c82..e4671c028 100644 --- a/infrastructure/network/rpcclient/rpcclient.go +++ b/infrastructure/network/rpcclient/rpcclient.go @@ -143,6 +143,9 @@ func (c *RPCClient) handleClientDisconnected() { } func (c *RPCClient) handleClientError(err error) { + if atomic.LoadUint32(&c.isClosed) == 1 { + return + } log.Warnf("Received error from client: %s", err) c.handleClientDisconnected() } @@ -159,7 +162,7 @@ func (c *RPCClient) Close() error { return errors.Errorf("Cannot close a client that had already been closed") } c.rpcRouter.router.Close() - return nil + return c.GRPCClient.Close() } // Address returns the address the RPC client connected to