biospb 847aafc91f
Fix RPC connections counting (#2026)
* Fix RPC connections counting

* show incoming connections count

* Use the flag RPCMaxClients instead of the const RPCMaxInboundConnections

* Add grpc server name to log message

Co-authored-by: Michael Sutton <mikisiton2@gmail.com>
2022-06-15 22:49:36 +03:00

144 lines
4.0 KiB
Go

package grpcserver
import (
"context"
"fmt"
"github.com/kaspanet/kaspad/infrastructure/network/netadapter/server"
"github.com/kaspanet/kaspad/util/panics"
"github.com/pkg/errors"
"google.golang.org/grpc"
"google.golang.org/grpc/peer"
"net"
"sync"
"time"
)
// gRPCServer wraps a *grpc.Server together with the addresses it listens on,
// the handler for newly accepted inbound connections, and the bookkeeping
// used to cap how many inbound connections are handled at the same time.
type gRPCServer struct {
	// onConnectedHandler is called with every accepted inbound connection.
	// Start refuses to run while it is nil.
	onConnectedHandler server.OnConnectedHandler
	// listeningAddresses are the TCP addresses Start listens on.
	listeningAddresses []string
	// server is the underlying gRPC server shared by all listeners.
	server *grpc.Server
	// name identifies this server in log messages and spawned goroutine names.
	name string
	// maxInboundConnections caps concurrently handled inbound connections;
	// a value of 0 or less disables the cap.
	maxInboundConnections int
	// inboundConnectionCount is the number of inbound connections currently
	// being handled. It is only read or written while holding
	// inboundConnectionCountLock.
	inboundConnectionCount int
	// inboundConnectionCountLock guards inboundConnectionCount.
	inboundConnectionCountLock *sync.Mutex
}
// newGRPCServer builds a gRPCServer called name that will listen on
// listeningAddresses once started. maxMessageSize bounds both received and
// sent gRPC messages. maxInboundConnections caps how many inbound
// connections may be handled at once; 0 or less means no cap.
// The returned server has no connection handler yet; set one with
// SetOnConnectedHandler before calling Start.
func newGRPCServer(listeningAddresses []string, maxMessageSize int, maxInboundConnections int, name string) *gRPCServer {
	log.Debugf("Created new %s GRPC server with maxMessageSize %d and maxInboundConnections %d", name, maxMessageSize, maxInboundConnections)

	underlying := grpc.NewServer(
		grpc.MaxRecvMsgSize(maxMessageSize),
		grpc.MaxSendMsgSize(maxMessageSize),
	)

	// inboundConnectionCount starts at its zero value.
	return &gRPCServer{
		name:                       name,
		listeningAddresses:         listeningAddresses,
		server:                     underlying,
		maxInboundConnections:      maxInboundConnections,
		inboundConnectionCountLock: new(sync.Mutex),
	}
}
// Start begins serving on every configured listening address, in order.
// It fails immediately if no connection handler has been set. Otherwise it
// returns the error from the first address that cannot be listened on.
// NOTE(review): listeners opened before a failing address are not closed
// here; presumably the caller invokes Stop on error — confirm.
func (s *gRPCServer) Start() error {
	if s.onConnectedHandler == nil {
		return errors.New("onConnectedHandler is nil")
	}

	for i := range s.listeningAddresses {
		if err := s.listenOn(s.listeningAddresses[i]); err != nil {
			return err
		}
	}

	return nil
}
// listenOn opens a TCP listener on listenAddr and serves the gRPC server on
// it from a separately spawned goroutine. A failure to open the listener is
// returned, wrapped with the server name. If Serve later returns a non-nil
// error, the process is terminated through panics.Exit.
func (s *gRPCServer) listenOn(listenAddr string) error {
	listener, err := net.Listen("tcp", listenAddr)
	if err != nil {
		return errors.Wrapf(err, "%s error listening on %s", s.name, listenAddr)
	}

	serveForever := func() {
		if serveErr := s.server.Serve(listener); serveErr != nil {
			panics.Exit(log, fmt.Sprintf("error serving %s on %s: %+v", s.name, listenAddr, serveErr))
		}
	}
	goroutineName := fmt.Sprintf("%s.gRPCServer.listenOn-Serve", s.name)
	spawn(goroutineName, serveForever)

	log.Infof("%s Server listening on %s", s.name, listener.Addr())
	return nil
}
// Stop shuts the server down. It first attempts a graceful stop in a
// spawned goroutine. If that has not finished within stopTimeout, it logs a
// warning and stops the server forcibly. It always returns nil.
func (s *gRPCServer) Stop() error {
	const stopTimeout = 2 * time.Second

	gracefulStopDone := make(chan struct{})
	spawn("gRPCServer.Stop", func() {
		s.server.GracefulStop()
		close(gracefulStopDone)
	})

	timeout := time.NewTimer(stopTimeout)
	defer timeout.Stop()

	select {
	case <-gracefulStopDone:
		// Graceful shutdown completed in time.
	case <-timeout.C:
		log.Warnf("Could not gracefully stop %s: timed out after %s", s.name, stopTimeout)
		s.server.Stop()
	}

	return nil
}
// SetOnConnectedHandler registers the callback that receives every accepted
// inbound connection. Start refuses to run until a non-nil handler has been
// registered.
func (s *gRPCServer) SetOnConnectedHandler(onConnectedHandler server.OnConnectedHandler) {
	s.onConnectedHandler = onConnectedHandler
}
// handleInboundConnection serves a single inbound stream for as long as the
// resulting connection stays alive. It works in four steps:
//   - It reserves an inbound-connection slot, failing when the limit is reached.
//   - It resolves the remote TCP address from ctx.
//   - It hands a new connection to onConnectedHandler.
//   - It blocks until that connection's stopChan is closed.
//
// The reserved slot is released on every return path.
func (s *gRPCServer) handleInboundConnection(ctx context.Context, stream grpcStream) error {
	connectionCount, err := s.incrementInboundConnectionCountAndLimitIfRequired()
	if err != nil {
		return err
	}
	defer s.decrementInboundConnectionCount()

	peerInfo, found := peer.FromContext(ctx)
	if !found {
		return errors.Errorf("Error getting stream peer info from context")
	}
	tcpAddress, isTCP := peerInfo.Addr.(*net.TCPAddr)
	if !isTCP {
		return errors.Errorf("non-tcp connections are not supported")
	}

	connection := newConnection(s, tcpAddress, stream, nil)
	if handlerErr := s.onConnectedHandler(connection); handlerErr != nil {
		return handlerErr
	}

	log.Infof("%s Incoming connection from %s #%d", s.name, peerInfo.Addr, connectionCount)

	// Returning from this handler ends the stream, so wait here until the
	// connection reports that it has stopped.
	<-connection.stopChan
	return nil
}
// incrementInboundConnectionCountAndLimitIfRequired reserves a slot for a new
// inbound connection. A positive maxInboundConnections enables the limit; a
// value of 0 or less disables it. When the limit is enabled and the number of
// inbound connections currently being handled has already reached it, no slot
// is reserved and an error is returned.
//
// On success it returns the inbound connection count including the slot just
// reserved. On failure the returned count is 0 and must be ignored. Every
// successful call must be paired with decrementInboundConnectionCount.
func (s *gRPCServer) incrementInboundConnectionCountAndLimitIfRequired() (int, error) {
	s.inboundConnectionCountLock.Lock()
	defer s.inboundConnectionCountLock.Unlock()

	// Compare with >= rather than == so the cap can never be bypassed, even
	// if the count somehow ends up above the configured maximum.
	if s.maxInboundConnections > 0 && s.inboundConnectionCount >= s.maxInboundConnections {
		log.Warnf("Limit of %d %s inbound connections has been exceeded", s.maxInboundConnections, s.name)
		return 0, errors.Errorf("limit of %d %s inbound connections has been exceeded", s.maxInboundConnections, s.name)
	}

	s.inboundConnectionCount++
	return s.inboundConnectionCount, nil
}
// decrementInboundConnectionCount releases one inbound-connection slot
// previously reserved by incrementInboundConnectionCountAndLimitIfRequired.
func (s *gRPCServer) decrementInboundConnectionCount() {
	mu := s.inboundConnectionCountLock
	mu.Lock()
	defer mu.Unlock()

	s.inboundConnectionCount--
}