Compare commits

...

2 Commits

Author SHA1 Message Date
Elichai Turkel
1743dc694a Stop using big.Int to compare hashes (#912)
* Add a benchmark for Hash.Cmp()

* Compare hashes directly without going through big.Int

* Add a more thorough test for Hash.Cmp()
2020-09-10 17:32:45 +03:00
stasatdaglabs
8fb30a5895 [NOD-1367] Fix a race condition with notification listeners (#925)
* [NOD-1367] Add an error handler to GRPCClient.

* [NOD-1367] Fix race condition with notification listeners.

* [NOD-1367] Make go vet happy.
2020-09-10 16:04:56 +03:00
5 changed files with 72 additions and 7 deletions

View File

@@ -44,8 +44,12 @@ func (m *Manager) routerInitializer(router *router.Router, netConnection *netada
err := m.handleIncomingMessages(router, incomingRoute)
m.handleError(err, netConnection)
})
notificationListener := m.context.NotificationManager.AddListener(router)
spawn("routerInitializer-handleOutgoingNotifications", func() {
err := m.handleOutgoingNotifications(router)
defer m.context.NotificationManager.RemoveListener(router)
err := m.handleOutgoingNotifications(notificationListener)
m.handleError(err, netConnection)
})
}
@@ -72,9 +76,7 @@ func (m *Manager) handleIncomingMessages(router *router.Router, incomingRoute *r
}
}
func (m *Manager) handleOutgoingNotifications(router *router.Router) error {
notificationListener := m.context.NotificationManager.AddListener(router)
defer m.context.NotificationManager.RemoveListener(router)
func (m *Manager) handleOutgoingNotifications(notificationListener *rpccontext.NotificationListener) error {
for {
err := notificationListener.ProcessNextNotification()
if err != nil {

View File

@@ -12,9 +12,13 @@ import (
"time"
)
// OnErrorHandler defines a handler function for when errors occur
type OnErrorHandler func(err error)
// GRPCClient is a gRPC-based RPC client
type GRPCClient struct {
stream protowire.RPC_MessageStreamClient
stream protowire.RPC_MessageStreamClient
onErrorHandler OnErrorHandler
}
// Connect connects to the RPC server with the given address
@@ -41,6 +45,11 @@ func (c *GRPCClient) Disconnect() error {
return c.stream.CloseSend()
}
// SetOnErrorHandler sets the client's onErrorHandler
func (c *GRPCClient) SetOnErrorHandler(onErrorHandler OnErrorHandler) {
c.onErrorHandler = onErrorHandler
}
// AttachRouter attaches the given router to the client and starts
// sending/receiving messages via it
func (c *GRPCClient) AttachRouter(router *router.Router) {
@@ -101,5 +110,9 @@ func (c *GRPCClient) handleError(err error) {
}
return
}
if c.onErrorHandler != nil {
c.onErrorHandler(err)
return
}
panic(err)
}

View File

@@ -26,7 +26,7 @@ type RPCClient struct {
func NewRPCClient(rpcAddress string) (*RPCClient, error) {
rpcClient, err := grpcclient.Connect(rpcAddress)
if err != nil {
return nil, errors.Wrapf(err, "error connecting to address %s", rpcClient)
return nil, errors.Wrapf(err, "error connecting to address %s", rpcAddress)
}
rpcRouter, err := buildRPCRouter()
if err != nil {

View File

@@ -224,7 +224,16 @@ func HashToBig(hash *Hash) *big.Int {
// +1 if hash > target
//
func (hash *Hash) Cmp(target *Hash) int {
return HashToBig(hash).Cmp(HashToBig(target))
// We compare the hashes backwards because Hash is stored as a little endian byte array.
for i := HashSize - 1; i >= 0; i-- {
switch {
case hash[i] < target[i]:
return -1
case hash[i] > target[i]:
return 1
}
}
return 0
}
// Less returns true iff hash a is less than hash b

View File

@@ -9,6 +9,7 @@ import (
"encoding/hex"
"errors"
"math/big"
"math/rand"
"reflect"
"testing"
)
@@ -425,3 +426,43 @@ func TestSort(t *testing.T) {
}
}
}
func hashFlipBit(hash Hash, bit int) Hash {
word := bit / 8
bit = bit % 8
hash[word] ^= 1 << bit
return hash
}
func TestHash_Cmp(t *testing.T) {
r := rand.New(rand.NewSource(1))
for i := 0; i < 100; i++ {
hash := Hash{}
n, err := r.Read(hash[:])
if err != nil {
t.Fatalf("Failed generating a random hash '%s'", err)
} else if n != len(hash) {
t.Fatalf("Failed generating a random hash, expected reading: %d. instead read: %d.", len(hash), n)
}
hashBig := HashToBig(&hash)
// Iterate bit by bit, flip it and compare.
for bit := 0; bit < HashSize*8; bit++ {
newHash := hashFlipBit(hash, bit)
if hash.Cmp(&newHash) != hashBig.Cmp(HashToBig(&newHash)) {
t.Errorf("Hash.Cmp disagrees with big.Int.Cmp newHash: %s, hash: %s", newHash, hash)
}
}
}
}
func BenchmarkHash_Cmp(b *testing.B) {
hash0, err := NewHashFromStr("3333333333333333333333333333333333333333333333333333333333333333")
if err != nil {
b.Fatal(err)
}
for n := 0; n < b.N; n++ {
hash0.Cmp(hash0)
}
}