diff --git a/domain/consensus/processes/pruningmanager/pruningmanager.go b/domain/consensus/processes/pruningmanager/pruningmanager.go index dd5798afd..e0b7ca4c3 100644 --- a/domain/consensus/processes/pruningmanager/pruningmanager.go +++ b/domain/consensus/processes/pruningmanager/pruningmanager.go @@ -772,78 +772,41 @@ func (pm *pruningManager) calculateDiffBetweenPreviousAndCurrentPruningPoints(st if err != nil { return nil, err } - currentPruningGhostDAG, err := pm.ghostdagDataStore.Get(pm.databaseContext, stagingArea, currentPruningHash, false) + + utxoDiff := utxo.NewMutableUTXODiff() + + iterator, err := pm.dagTraversalManager.SelectedChildIterator(stagingArea, currentPruningHash, previousPruningHash, false) if err != nil { return nil, err } - previousPruningGhostDAG, err := pm.ghostdagDataStore.Get(pm.databaseContext, stagingArea, previousPruningHash, false) - if err != nil { - return nil, err + defer iterator.Close() + + for ok := iterator.First(); ok; ok = iterator.Next() { + child, err := iterator.Get() + if err != nil { + return nil, err + } + chainBlockAcceptanceData, err := pm.acceptanceDataStore.Get(pm.databaseContext, stagingArea, child) + if err != nil { + return nil, err + } + chainBlockHeader, err := pm.blockHeaderStore.BlockHeader(pm.databaseContext, stagingArea, child) + if err != nil { + return nil, err + } + for _, blockAcceptanceData := range chainBlockAcceptanceData { + for _, transactionAcceptanceData := range blockAcceptanceData.TransactionAcceptanceData { + if transactionAcceptanceData.IsAccepted { + err = utxoDiff.AddTransaction(transactionAcceptanceData.Transaction, chainBlockHeader.DAAScore()) + if err != nil { + return nil, err + } + } + } + } } - currentPruningCurrentDiffChild := currentPruningHash - previousPruningCurrentDiffChild := previousPruningHash - // We need to use BlueWork because it's the only thing that's monotonic in the whole DAG - // We use the BlueWork to know which point is currently lower on the DAG so we can keep climbing its children, - // that way we keep climbing on the lowest point until they both reach the exact same descendant - currentPruningCurrentDiffChildBlueWork := currentPruningGhostDAG.BlueWork() - previousPruningCurrentDiffChildBlueWork := previousPruningGhostDAG.BlueWork() - - var diffHashesFromPrevious []*externalapi.DomainHash - var diffHashesFromCurrent []*externalapi.DomainHash - for { - // if currentPruningCurrentDiffChildBlueWork > previousPruningCurrentDiffChildBlueWork - if currentPruningCurrentDiffChildBlueWork.Cmp(previousPruningCurrentDiffChildBlueWork) == 1 { - diffHashesFromPrevious = append(diffHashesFromPrevious, previousPruningCurrentDiffChild) - previousPruningCurrentDiffChild, err = pm.utxoDiffStore.UTXODiffChild(pm.databaseContext, stagingArea, previousPruningCurrentDiffChild) - if err != nil { - return nil, err - } - diffChildGhostDag, err := pm.ghostdagDataStore.Get(pm.databaseContext, stagingArea, previousPruningCurrentDiffChild, false) - if err != nil { - return nil, err - } - previousPruningCurrentDiffChildBlueWork = diffChildGhostDag.BlueWork() - } else if currentPruningCurrentDiffChild.Equal(previousPruningCurrentDiffChild) { - break - } else { - diffHashesFromCurrent = append(diffHashesFromCurrent, currentPruningCurrentDiffChild) - currentPruningCurrentDiffChild, err = pm.utxoDiffStore.UTXODiffChild(pm.databaseContext, stagingArea, currentPruningCurrentDiffChild) - if err != nil { - return nil, err - } - diffChildGhostDag, err := pm.ghostdagDataStore.Get(pm.databaseContext, stagingArea, currentPruningCurrentDiffChild, false) - if err != nil { - return nil, err - } - currentPruningCurrentDiffChildBlueWork = diffChildGhostDag.BlueWork() - } - } - // The order in which we apply the diffs should be from top to bottom, but we traversed from bottom to top - // so we apply the diffs in reverse order. - oldDiff := utxo.NewMutableUTXODiff() - for i := len(diffHashesFromPrevious) - 1; i >= 0; i-- { - utxoDiff, err := pm.utxoDiffStore.UTXODiff(pm.databaseContext, stagingArea, diffHashesFromPrevious[i]) - if err != nil { - return nil, err - } - err = oldDiff.WithDiffInPlace(utxoDiff) - if err != nil { - return nil, err - } - } - newDiff := utxo.NewMutableUTXODiff() - for i := len(diffHashesFromCurrent) - 1; i >= 0; i-- { - utxoDiff, err := pm.utxoDiffStore.UTXODiff(pm.databaseContext, stagingArea, diffHashesFromCurrent[i]) - if err != nil { - return nil, err - } - err = newDiff.WithDiffInPlace(utxoDiff) - if err != nil { - return nil, err - } - } - return oldDiff.DiffFrom(newDiff.ToImmutable()) + return utxoDiff.ToImmutable(), err } // finalityScore is the number of finality intervals passed since diff --git a/infrastructure/network/rpcclient/grpcclient/grpcclient.go b/infrastructure/network/rpcclient/grpcclient/grpcclient.go index e15731aa3..4042c6267 100644 --- a/infrastructure/network/rpcclient/grpcclient/grpcclient.go +++ b/infrastructure/network/rpcclient/grpcclient/grpcclient.go @@ -22,6 +22,7 @@ type OnDisconnectedHandler func() // GRPCClient is a gRPC-based RPC client type GRPCClient struct { stream protowire.RPC_MessageStreamClient + connection *grpc.ClientConn onErrorHandler OnErrorHandler onDisconnectedHandler OnDisconnectedHandler } @@ -43,7 +44,12 @@ func Connect(address string) (*GRPCClient, error) { if err != nil { return nil, errors.Wrapf(err, "error getting client stream for %s", address) } - return &GRPCClient{stream: stream}, nil + return &GRPCClient{stream: stream, connection: gRPCConnection}, nil +} + +// Close closes the underlying grpc connection +func (c *GRPCClient) Close() error { + return c.connection.Close() } // Disconnect disconnects from the RPC server diff --git a/infrastructure/network/rpcclient/rpcclient.go b/infrastructure/network/rpcclient/rpcclient.go index 7256f6c82..e4671c028 100644 --- a/infrastructure/network/rpcclient/rpcclient.go +++ b/infrastructure/network/rpcclient/rpcclient.go @@ -143,6 +143,9 @@ func (c *RPCClient) handleClientDisconnected() { } func (c *RPCClient) handleClientError(err error) { + if atomic.LoadUint32(&c.isClosed) == 1 { + return + } log.Warnf("Received error from client: %s", err) c.handleClientDisconnected() } @@ -159,7 +162,7 @@ func (c *RPCClient) Close() error { return errors.Errorf("Cannot close a client that had already been closed") } c.rpcRouter.router.Close() - return nil + return c.GRPCClient.Close() } // Address returns the address the RPC client connected to diff --git a/testing/integration/rpc_test.go b/testing/integration/rpc_test.go index 44fb1e40d..76caf1ff5 100644 --- a/testing/integration/rpc_test.go +++ b/testing/integration/rpc_test.go @@ -2,6 +2,7 @@ package integration import ( "github.com/kaspanet/kaspad/infrastructure/config" + "runtime" "testing" "time" @@ -26,6 +27,37 @@ func newTestRPCClient(rpcAddress string) (*testRPCClient, error) { }, nil } +func connectAndClose(rpcAddress string) error { + client, err := rpcclient.NewRPCClient(rpcAddress) + if err != nil { + return err + } + defer client.Close() + return nil +} + +func TestRPCClientGoroutineLeak(t *testing.T) { + _, teardown := setupHarness(t, &harnessParams{ + p2pAddress: p2pAddress1, + rpcAddress: rpcAddress1, + miningAddress: miningAddress1, + miningAddressPrivateKey: miningAddress1PrivateKey, + }) + defer teardown() + numGoroutinesBefore := runtime.NumGoroutine() + for i := 1; i < 100; i++ { + err := connectAndClose(rpcAddress1) + if err != nil { + t.Fatalf("Failed to set up an RPC client: %s", err) + } + time.Sleep(10 * time.Millisecond) + if runtime.NumGoroutine() > numGoroutinesBefore+10 { + t.Fatalf("Number of goroutines is increasing for each RPC client open (%d -> %d), which indicates a memory leak", + numGoroutinesBefore, runtime.NumGoroutine()) + } + } +} + func TestRPCMaxInboundConnections(t *testing.T) { harness, teardown := setupHarness(t, &harnessParams{ p2pAddress: p2pAddress1,