From 5c5491e1e44b61373a749307e8780903c36fd077 Mon Sep 17 00:00:00 2001 From: Ori Newman Date: Wed, 15 May 2019 16:07:37 +0300 Subject: [PATCH] [NOD-172] Port ECMH from bchd and fix Remove to preserve commutativity (#292) * [NOD-172] Port EMCH from bchd * [NOD-172] Fix hdkeychain.TestErrors and add btcec.TestRecoverCompact * [NOD-172] Make ECMH immutable * [NOD-172] Fix gofmt errors * [NOD-172] Add TestMultiset_NewMultisetFromDataSlice and fix Point to be immutable * [NOD-172] Fix gofmt errors * [NOD-172] Add test for checking that the Union of a multiset and its inverse is zero --- btcec/ecmh.go | 150 ++++++++++++++++++++ btcec/ecmh_test.go | 213 ++++++++++++++++++++++++++++ btcec/pubkey.go | 11 +- btcec/signature_test.go | 62 ++++++++ util/hdkeychain/extendedkey_test.go | 5 +- 5 files changed, 438 insertions(+), 3 deletions(-) create mode 100644 btcec/ecmh.go create mode 100644 btcec/ecmh_test.go diff --git a/btcec/ecmh.go b/btcec/ecmh.go new file mode 100644 index 000000000..a447f5066 --- /dev/null +++ b/btcec/ecmh.go @@ -0,0 +1,150 @@ +package btcec + +import ( + "crypto/sha256" + "encoding/binary" + "math/big" + + "github.com/daglabs/btcd/util/daghash" +) + +// Multiset tracks the state of a multiset as used to calculate the ECMH +// (elliptic curve multiset hash) hash of an unordered set. The state is +// a point on the curve. New elements are hashed onto a point on the curve +// and then added to the current state. Hence elements can be added in any +// order and we can also remove elements to return to a prior hash. +type Multiset struct { + curve *KoblitzCurve + x *big.Int + y *big.Int +} + +// NewMultiset returns an empty multiset. The hash of an empty set +// is the 32 byte value of zero. +func NewMultiset(curve *KoblitzCurve) *Multiset { + return &Multiset{curve: curve, x: big.NewInt(0), y: big.NewInt(0)} +} + +// NewMultisetFromPoint initializes a new multiset with the given x, y +// coordinate. +func NewMultisetFromPoint(curve *KoblitzCurve, x, y *big.Int) *Multiset { + var copyX, copyY big.Int + if x != nil { + copyX.Set(x) + } + if y != nil { + copyY.Set(y) + } + return &Multiset{curve: curve, x: ©X, y: ©Y} +} + +// NewMultisetFromDataSlice gets a curve and a slice of byte +// slices, creates an empty multiset, hashes each data and +// add it to the multiset, and return the resulting multiset. +func NewMultisetFromDataSlice(curve *KoblitzCurve, datas [][]byte) *Multiset { + ms := NewMultiset(curve) + for _, data := range datas { + x, y := hashToPoint(curve, data) + ms.addPoint(x, y) + } + return ms +} + +func (ms *Multiset) clone() *Multiset { + return NewMultisetFromPoint(ms.curve, ms.x, ms.y) +} + +// Add hashes the data onto the curve and returns +// a multiset with the new resulting point. +func (ms *Multiset) Add(data []byte) *Multiset { + newMs := ms.clone() + x, y := hashToPoint(ms.curve, data) + newMs.addPoint(x, y) + return newMs +} + +func (ms *Multiset) addPoint(x, y *big.Int) { + ms.x, ms.y = ms.curve.Add(ms.x, ms.y, x, y) +} + +// Remove hashes the data onto the curve, subtracts +// the point from the existing multiset, and returns +// a multiset with the new point. This function +// will execute regardless of whether or not the passed +// data was previously added to the set. Hence if you +// remove an element that was never added and also remove +// all the elements that were added, you will not get +// back to the point at infinity (empty set). +func (ms *Multiset) Remove(data []byte) *Multiset { + newMs := ms.clone() + x, y := hashToPoint(ms.curve, data) + newMs.removePoint(x, y) + return newMs +} + +func (ms *Multiset) removePoint(x, y *big.Int) { + y.Neg(y).Mod(y, ms.curve.P) + ms.x, ms.y = ms.curve.Add(ms.x, ms.y, x, y) +} + +// Union will add the point of the passed multiset instance to the point +// of this multiset and will return a multiset with the resulting point. +func (ms *Multiset) Union(otherMultiset *Multiset) *Multiset { + newMs := ms.clone() + otherMsCopy := otherMultiset.clone() + newMs.addPoint(otherMsCopy.x, otherMsCopy.y) + return newMs +} + +// Subtract will remove the point of the passed multiset instance from the point +// of this multiset and will return a multiset with the resulting point. +func (ms *Multiset) Subtract(otherMultiset *Multiset) *Multiset { + newMs := ms.clone() + otherMsCopy := otherMultiset.clone() + newMs.removePoint(otherMsCopy.x, otherMsCopy.y) + return newMs +} + +// Hash serializes and returns the hash of the multiset. The hash of an empty +// set is the 32 byte value of zero. The hash of a non-empty multiset is the +// sha256 hash of the 32 byte x value concatenated with the 32 byte y value. +func (ms *Multiset) Hash() daghash.Hash { + if ms.x.Sign() == 0 && ms.y.Sign() == 0 { + return daghash.Hash{} + } + + hash := sha256.Sum256(append(ms.x.Bytes(), ms.y.Bytes()...)) + return daghash.Hash(hash) +} + +// Point returns a copy of the x and y coordinates of the current multiset state. +func (ms *Multiset) Point() (x *big.Int, y *big.Int) { + var copyX, copyY big.Int + copyX.Set(ms.x) + copyY.Set(ms.y) + return ©X, ©Y +} + +// hashToPoint hashes the passed data into a point on the curve. The x value +// is sha256(n, sha256(data)) where n starts at zero. If the resulting x value +// is not in the field or x^3+7 is not quadratic residue then n is incremented +// and we try again. There is a 50% chance of success for any given iteration. +func hashToPoint(curve *KoblitzCurve, data []byte) (x *big.Int, y *big.Int) { + i := uint64(0) + var err error + h := sha256.Sum256(data) + n := make([]byte, 8) + for { + binary.LittleEndian.PutUint64(n, i) + h2 := sha256.Sum256(append(n, h[:]...)) + + x = new(big.Int).SetBytes(h2[:]) + + y, err = decompressPoint(curve, x, false) + if err == nil && x.Cmp(curve.N) < 0 { + break + } + i++ + } + return x, y +} diff --git a/btcec/ecmh_test.go b/btcec/ecmh_test.go new file mode 100644 index 000000000..4b81fd9f1 --- /dev/null +++ b/btcec/ecmh_test.go @@ -0,0 +1,213 @@ +package btcec + +import ( + "bytes" + "encoding/hex" + "testing" + + "github.com/daglabs/btcd/util/daghash" +) + +var testVectors = []struct { + dataElementHex string + point [2]string + ecmhHash string + cumulativeHash string +}{ + { + "982051fd1e4ba744bbbe680e1fee14677ba1a3c3540bf7b1cdb606e857233e0e00000000010000000100f2052a0100000043410496b538e853519c726a2c91e61ec11600ae1390813a627c66fb8be7947be63c52da7589379515d4e0a604f8141781e62294721166bf621e73a82cbf2342c858eeac", + [2]string{"4f9a5dce69067bf28603e73a7af4c3650b16539b95bad05eee95dfc94d1efe2c", "346d5b777881f2729e7f89b2de4e8e79c7f2f42d1a0b25a8f10becb66e2d0f98"}, + "9378d88aa60cfba3032cb19f27891886e26fc6de1afa340c1787a633591983f8", + "", + }, + { + "d5fdcc541e25de1c7a5addedf24858b8bb665c9f36ef744ee42c316022c90f9b00000000020000000100f2052a010000004341047211a824f55b505228e4c3d5194c1fcfaa15a456abdf37f9b9d97a4040afc073dee6c89064984f03385237d92167c13e236446b417ab79a0fcae412ae3316b77ac", + [2]string{"68cf91eb2388a0287c13d46011c73fb8efb6be89c0867a47feccb2d11c390d2d", "f42ba72b1079d3d941881836f88b5dcd7c207a6a4839f129272c77ebb7194d42"}, + "e2f3dc6f3aa867c50bd41b80aa3bdafcc9e1d13a6292ff8a5da95da123d185ef", + "afaa1f7ba0bd8a789422fdd6968639a4b8575baf7d54342a987073d038fdbafa", + }, + { + "44f672226090d85db9a9f2fbfe5f0f9609b387af7be5b7fbb7a1767c831c9e9900000000030000000100f2052a0100000043410494b9d3e76c5b1629ecf97fff95d7a4bbdac87cc26099ada28066c6ff1eb9191223cd897194a08d0c2726c5747f1db49e8cf90e75dc3e3550ae9b30086f3cd5aaac", + [2]string{"359c6f59859d1d5af8e7081905cb6bb734c010be8680c14b5a89ee315694fc2b", "fb6ba531d4bd83b14c970ad1bec332a8ae9a05706cd5df7fd91a2f2cc32482fe"}, + "ffed6804617a4a33b1037cdd26426e61fde0faa2c0cc045efffa17c00ff4adcf", + "e236a694532be6a4926ab8d5b1ff9cbfe638178e0008b0a8c5e87c3da2cdbc1c", + }, +} + +func TestHashToPoint(t *testing.T) { + for _, test := range testVectors { + data, err := hex.DecodeString(test.dataElementHex) + if err != nil { + t.Fatal(err) + } + x, y := hashToPoint(S256(), data) + if hex.EncodeToString(x.Bytes()) != test.point[0] || hex.EncodeToString(y.Bytes()) != test.point[1] { + t.Fatal("hashToPoint return incorrect point") + } + } +} + +func TestMultiset_Hash(t *testing.T) { + for _, test := range testVectors { + data, err := hex.DecodeString(test.dataElementHex) + if err != nil { + t.Fatal(err) + } + x, y := hashToPoint(S256(), data) + m := NewMultisetFromPoint(S256(), x, y) + if m.Hash().String() != test.ecmhHash { + t.Fatal("Multiset-Hash returned incorrect hash serialization") + } + } + m := NewMultiset(S256()) + emptySet := m.Hash() + zeroHash := daghash.Hash{} + if !bytes.Equal(emptySet[:], zeroHash[:]) { + t.Fatal("Empty set did not return zero hash") + } +} + +func TestMultiset_AddRemove(t *testing.T) { + m := NewMultiset(S256()) + for i, test := range testVectors { + data, err := hex.DecodeString(test.dataElementHex) + if err != nil { + t.Fatal(err) + } + m = m.Add(data) + if test.cumulativeHash != "" && m.Hash().String() != test.cumulativeHash { + t.Fatalf("Test #%d: Multiset-Add returned incorrect hash. Expected %s but got %s", i, test.cumulativeHash, m.Hash()) + } + } + + for i := len(testVectors) - 1; i > 0; i-- { + data, err := hex.DecodeString(testVectors[i].dataElementHex) + if err != nil { + t.Fatal(err) + } + m = m.Remove(data) + if testVectors[i-1].cumulativeHash != "" && m.Hash().String() != testVectors[i-1].cumulativeHash { + t.Fatalf("Test #%d: Multiset-Remove returned incorrect hash. Expected %s but got %s", i, testVectors[i].cumulativeHash, m.Hash()) + } + } +} + +func TestMultiset_UnionSubtract(t *testing.T) { + m1 := NewMultiset(S256()) + zeroHash := m1.Hash().String() + + for _, test := range testVectors { + data, err := hex.DecodeString(test.dataElementHex) + if err != nil { + t.Fatal(err) + } + m1 = m1.Add(data) + } + + m2 := NewMultiset(S256()) + for _, test := range testVectors { + data, err := hex.DecodeString(test.dataElementHex) + if err != nil { + t.Fatal(err) + } + m2 = m2.Remove(data) + } + m1m2Union := m1.Union(m2) + if m1m2Union.Hash().String() != zeroHash { + t.Fatalf("m1m2Union was expected to return to have zero hash, but was %s instead", m1m2Union.Hash()) + } + + m1Inverse := NewMultiset(S256()).Subtract(m1) + m1InverseM1Union := m1.Union(m1Inverse) + if m1InverseM1Union.Hash().String() != zeroHash { + t.Fatalf("m1InverseM1Union was expected to have zero hash, but got %s instead", m1InverseM1Union.Hash()) + } + + m1SubtractM1 := m1.Subtract(m1) + if m1SubtractM1.Hash().String() != zeroHash { + t.Fatalf("m1SubtractM1 was expected to have zero hash, but got %s instead", m1SubtractM1.Hash()) + } +} + +func TestMultiset_Commutativity(t *testing.T) { + m := NewMultiset(S256()) + zeroHash := m.Hash().String() + + // Check that if we subtract values from zero and then re-add them, we return to zero. + for _, test := range testVectors { + data, err := hex.DecodeString(test.dataElementHex) + if err != nil { + t.Fatal(err) + } + m = m.Remove(data) + } + + for _, test := range testVectors { + data, err := hex.DecodeString(test.dataElementHex) + if err != nil { + t.Fatal(err) + } + m = m.Add(data) + } + if m.Hash().String() != zeroHash { + t.Fatalf("m was expected to be zero hash, but was %s instead", m.Hash()) + } + + // Here we first remove an element from an empty multiset, and then add some other + // elements, and then we create a new empty multiset, then we add the same elements + // we added to the previous multiset, and then we remove the same element we remove + // the same element we removed from the previous multiset. According to commutativity + // laws, the result should be the same. + removeIndex := 0 + removeData, err := hex.DecodeString(testVectors[removeIndex].dataElementHex) + if err != nil { + t.Fatal(err) + } + + m1 := NewMultiset(S256()) + m1 = m1.Remove(removeData) + + for i, test := range testVectors { + if i != removeIndex { + data, err := hex.DecodeString(test.dataElementHex) + if err != nil { + t.Fatal(err) + } + m1 = m1.Add(data) + } + } + + m2 := NewMultiset(S256()) + for i, test := range testVectors { + if i != removeIndex { + data, err := hex.DecodeString(test.dataElementHex) + if err != nil { + t.Fatal(err) + } + m2 = m2.Add(data) + } + } + m2 = m2.Remove(removeData) + + if m1.Hash().String() != m2.Hash().String() { + t.Fatalf("m1 and m2 was exepcted to have the same hash, but got instead m1 %s and m2 %s", m1.Hash(), m2.Hash()) + } +} + +func TestMultiset_NewMultisetFromDataSlice(t *testing.T) { + m1 := NewMultiset(S256()) + datas := make([][]byte, 0, len(testVectors)) + for _, test := range testVectors { + data, err := hex.DecodeString(test.dataElementHex) + if err != nil { + t.Fatal(err) + } + datas = append(datas, data) + m1 = m1.Add(data) + } + + m2 := NewMultisetFromDataSlice(S256(), datas) + if m1.Hash().String() != m2.Hash().String() { + t.Fatalf("m1 and m2 was exepcted to have the same hash, but got instead m1 %s and m2 %s", m1.Hash(), m2.Hash()) + } +} diff --git a/btcec/pubkey.go b/btcec/pubkey.go index b74917718..9d577a9b0 100644 --- a/btcec/pubkey.go +++ b/btcec/pubkey.go @@ -32,13 +32,22 @@ func decompressPoint(curve *KoblitzCurve, x *big.Int, ybit bool) (*big.Int, erro x3 := new(big.Int).Mul(x, x) x3.Mul(x3, x) x3.Add(x3, curve.Params().B) + x3.Mod(x3, curve.Params().P) - // now calculate sqrt mod p of x2 + B + // Now calculate sqrt mod p of x^3 + B // This code used to do a full sqrt based on tonelli/shanks, // but this was replaced by the algorithms referenced in // https://bitcointalk.org/index.php?topic=162805.msg1712294#msg1712294 y := new(big.Int).Exp(x3, curve.QPlus1Div4(), curve.Params().P) + // Check that y is a square root of x^3 + B. + y2 := new(big.Int).Mul(y, y) + y2.Mod(y2, curve.Params().P) + if y2.Cmp(x3) != 0 { + return nil, fmt.Errorf("invalid square root") + } + + // Verify that y-coord has expected parity. if ybit != isOdd(y) { y.Sub(curve.Params().P, y) } diff --git a/btcec/signature_test.go b/btcec/signature_test.go index 69c8d79f0..00583d079 100644 --- a/btcec/signature_test.go +++ b/btcec/signature_test.go @@ -11,6 +11,7 @@ import ( "encoding/hex" "fmt" "math/big" + "reflect" "testing" ) @@ -535,6 +536,67 @@ func TestSignCompact(t *testing.T) { } } +// recoveryTests assert basic tests for public key recovery from signatures. +// The cases are borrowed from github.com/fjl/btcec-issue. +var recoveryTests = []struct { + msg string + sig string + pub string + err error +}{ + { + // Valid curve point recovered. + msg: "ce0677bb30baa8cf067c88db9811f4333d131bf8bcf12fe7065d211dce971008", + sig: "0190f27b8b488db00b00606796d2987f6a5f59ae62ea05effe84fef5b8b0e549984a691139ad57a3f0b906637673aa2f63d1f55cb1a69199d4009eea23ceaddc93", + pub: "04E32DF42865E97135ACFB65F3BAE71BDC86F4D49150AD6A440B6F15878109880A0A2B2667F7E725CEEA70C673093BF67663E0312623C8E091B13CF2C0F11EF652", + }, + { + // Invalid curve point recovered. + msg: "00c547e4f7b0f325ad1e56f57e26c745b09a3e503d86e00e5255ff7f715d3d1c", + sig: "0100b1693892219d736caba55bdb67216e485557ea6b6af75f37096c9aa6a5a75f00b940b1d03b21e36b0e47e79769f095fe2ab855bd91e3a38756b7d75a9c4549", + err: fmt.Errorf("invalid square root"), + }, + { + // Low R and S values. + msg: "ba09edc1275a285fb27bfe82c4eea240a907a0dbaf9e55764b8f318c37d5974f", + sig: "00000000000000000000000000000000000000000000000000000000000000002c0000000000000000000000000000000000000000000000000000000000000004", + pub: "04A7640409AA2083FDAD38B2D8DE1263B2251799591D840653FB02DBBA503D7745FCB83D80E08A1E02896BE691EA6AFFB8A35939A646F1FC79052A744B1C82EDC3", + }, +} + +func TestRecoverCompact(t *testing.T) { + for i, test := range recoveryTests { + msg := decodeHex(test.msg) + sig := decodeHex(test.sig) + + // Magic DER constant. + sig[0] += 27 + + pub, _, err := RecoverCompact(S256(), sig, msg) + + // Verify that returned error matches as expected. + if !reflect.DeepEqual(test.err, err) { + t.Errorf("unexpected error returned from pubkey "+ + "recovery #%d: wanted %v, got %v", + i, test.err, err) + continue + } + + // If check succeeded because a proper error was returned, we + // ignore the returned pubkey. + if err != nil { + continue + } + + // Otherwise, ensure the correct public key was recovered. + exPub, _ := ParsePubKey(decodeHex(test.pub), S256()) + if !exPub.IsEqual(pub) { + t.Errorf("unexpected recovered public key #%d: "+ + "want %v, got %v", i, exPub, pub) + } + } +} + func TestRFC6979(t *testing.T) { // Test vectors matching Trezor and CoreBitcoin implementations. // - https://github.com/trezor/trezor-crypto/blob/9fea8f8ab377dc514e40c6fd1f7c89a74c1d8dc6/tests.c#L432-L453 diff --git a/util/hdkeychain/extendedkey_test.go b/util/hdkeychain/extendedkey_test.go index 50e7a638c..81ea228aa 100644 --- a/util/hdkeychain/extendedkey_test.go +++ b/util/hdkeychain/extendedkey_test.go @@ -12,10 +12,11 @@ import ( "bytes" "encoding/hex" "errors" - "github.com/daglabs/btcd/util" "math" "reflect" "testing" + + "github.com/daglabs/btcd/util" ) // TestBIP0032Vectors tests the vectors provided by [BIP32] to ensure the @@ -856,7 +857,7 @@ func TestErrors(t *testing.T) { { name: "pubkey not on curve", key: "xpub661MyMwAqRbcFtXgS5sYJABqqG9YLmC4Q1Rdap9gSE8NqtwybGhePY2gZ1hr9Rwbk95YadvBkQXxzHBSngB8ndpW6QH7zhhsXZ2jHyZqPjk", - err: errors.New("pubkey isn't on secp256k1 curve"), + err: errors.New("invalid square root"), }, { name: "unsupported version",