From 11869905ae77484b5ddc099fe3a370f287bdd37e Mon Sep 17 00:00:00 2001 From: Anthony Romano Date: Fri, 10 Jun 2016 13:29:51 -0700 Subject: [PATCH] bridge: packet corruption and reordering With bonus bridge connection code refactor. --- tools/local-tester/bridge/bridge.go | 205 +++++++++++++++++--------- tools/local-tester/bridge/dispatch.go | 140 ++++++++++++++++++ 2 files changed, 273 insertions(+), 72 deletions(-) create mode 100644 tools/local-tester/bridge/dispatch.go diff --git a/tools/local-tester/bridge/bridge.go b/tools/local-tester/bridge/bridge.go index b636d3237..86c90857b 100644 --- a/tools/local-tester/bridge/bridge.go +++ b/tools/local-tester/bridge/bridge.go @@ -17,6 +17,7 @@ package main import ( "flag" + "fmt" "io" "io/ioutil" "log" @@ -26,78 +27,129 @@ import ( "time" ) -func bridge(conn net.Conn, remoteAddr string) { - outconn, err := net.Dial("tcp", flag.Args()[1]) +type bridgeConn struct { + in net.Conn + out net.Conn + d dispatcher +} + +func newBridgeConn(in net.Conn, d dispatcher) (*bridgeConn, error) { + out, err := net.Dial("tcp", flag.Args()[1]) if err != nil { - log.Println("oops:", err) - return + in.Close() + return nil, err } - log.Printf("bridging %v <-> %v\n", outconn.LocalAddr(), outconn.RemoteAddr()) - go io.Copy(conn, outconn) - io.Copy(outconn, conn) + return &bridgeConn{in, out, d}, nil } -func blackhole(conn net.Conn) { - log.Printf("blackholing connection %v <-> %v\n", conn.LocalAddr(), conn.RemoteAddr()) - io.Copy(ioutil.Discard, conn) - conn.Close() +func (b *bridgeConn) String() string { + return fmt.Sprintf("%v <-> %v", b.in.RemoteAddr(), b.out.RemoteAddr()) } -func readRemoteOnly(conn net.Conn, remoteAddr string) { - outconn, err := net.Dial("tcp", flag.Args()[1]) - if err != nil { - log.Println("oops:", err) - return - } - log.Printf("one way %v <- %v\n", outconn.LocalAddr(), outconn.RemoteAddr()) - io.Copy(conn, outconn) +func (b *bridgeConn) Close() { + b.in.Close() + b.out.Close() } -func writeRemoteOnly(conn net.Conn, remoteAddr string) { - outconn, err := net.Dial("tcp", flag.Args()[1]) - if err != nil { - log.Println("oops:", err) - return - } - log.Printf("one way %v -> %v\n", outconn.LocalAddr(), outconn.RemoteAddr()) - io.Copy(outconn, conn) +func bridge(b *bridgeConn) { + log.Println("bridging", b.String()) + go b.d.Copy(b.out, makeFetch(b.in)) + b.d.Copy(b.in, makeFetch(b.out)) } -func randCopy(conn net.Conn, outconn net.Conn) { - for rand.Intn(10) > 0 { +func timeBridge(b *bridgeConn) { + go func() { + t := time.Duration(rand.Intn(5)+1) * time.Second + time.Sleep(t) + log.Printf("killing connection %s after %v\n", b.String(), t) + b.Close() + }() + bridge(b) +} + +func blackhole(b *bridgeConn) { + log.Println("blackholing connection", b.String()) + io.Copy(ioutil.Discard, b.in) + b.Close() +} + +func readRemoteOnly(b *bridgeConn) { + log.Println("one way (<-)", b.String()) + b.d.Copy(b.in, makeFetch(b.out)) +} + +func writeRemoteOnly(b *bridgeConn) { + log.Println("one way (->)", b.String()) + b.d.Copy(b.out, makeFetch(b.in)) +} + +func corruptReceive(b *bridgeConn) { + log.Println("corruptReceive", b.String()) + go b.d.Copy(b.in, makeFetchCorrupt(makeFetch(b.out))) + b.d.Copy(b.out, makeFetch(b.in)) +} + +func corruptSend(b *bridgeConn) { + log.Println("corruptSend", b.String()) + go b.d.Copy(b.out, makeFetchCorrupt(makeFetch(b.in))) + b.d.Copy(b.in, makeFetch(b.out)) +} + +func makeFetch(c io.Reader) fetchFunc { + return func() ([]byte, error) { b := make([]byte, 4096) - n, err := outconn.Read(b) + n, err := c.Read(b) if err != nil { - return - } - _, err = conn.Write(b[:n]) - if err != nil { - return + return nil, err } + return b[:n], nil } } -func randomBlackhole(conn net.Conn, remoteAddr string) { - outconn, err := net.Dial("tcp", flag.Args()[1]) - if err != nil { - log.Println("oops:", err) - return +func makeFetchCorrupt(f func() ([]byte, error)) fetchFunc { + return func() ([]byte, error) { + b, err := f() + if err != nil { + return nil, err + } + // corrupt one byte approximately every 16K + for i := 0; i < len(b); i++ { + if rand.Intn(16*1024) == 0 { + b[i] = b[i] + 1 + } + } + return b, nil } - log.Printf("random blackhole: connection %v <-/-> %v\n", outconn.LocalAddr(), outconn.RemoteAddr()) +} + +func makeFetchRand(f func() ([]byte, error)) fetchFunc { + return func() ([]byte, error) { + if rand.Intn(10) == 0 { + return nil, fmt.Errorf("fetchRand: done") + } + b, err := f() + if err != nil { + return nil, err + } + return b, nil + } +} + +func randomBlackhole(b *bridgeConn) { + log.Println("random blackhole: connection", b.String()) var wg sync.WaitGroup wg.Add(2) go func() { - randCopy(conn, outconn) + b.d.Copy(b.in, makeFetchRand(makeFetch(b.out))) wg.Done() }() go func() { - randCopy(outconn, conn) + b.d.Copy(b.out, makeFetchRand(makeFetch(b.in))) wg.Done() }() wg.Wait() - conn.Close() - outconn.Close() + b.Close() } type config struct { @@ -111,10 +163,13 @@ type config struct { writeRemoteOnly bool readRemoteOnly bool randomBlackhole bool + corruptSend bool + corruptReceive bool + reorder bool } type acceptFaultFunc func() -type connFaultFunc func(net.Conn) +type connFaultFunc func(*bridgeConn) func main() { var cfg config @@ -128,7 +183,10 @@ func main() { flag.BoolVar(&cfg.timeClose, "time-close", true, "close after random time") flag.BoolVar(&cfg.writeRemoteOnly, "write-remote-only", true, "only write, no read") flag.BoolVar(&cfg.readRemoteOnly, "read-remote-only", true, "only read, no write") - flag.BoolVar(&cfg.randomBlackhole, "random-blockhole", true, "blackhole after data xfer") + flag.BoolVar(&cfg.randomBlackhole, "random-blackhole", true, "blackhole after data xfer") + flag.BoolVar(&cfg.corruptReceive, "corrupt-receive", true, "corrupt packets received from destination") + flag.BoolVar(&cfg.corruptSend, "corrupt-send", true, "corrupt packets sent to destination") + flag.BoolVar(&cfg.reorder, "reorder", true, "reorder packet delivery") flag.Parse() lAddr := flag.Args()[0] @@ -163,11 +221,11 @@ func main() { acceptFaults = append(acceptFaults, f) } - connFaults := []connFaultFunc{func(c net.Conn) { bridge(c, fwdAddr) }} + connFaults := []connFaultFunc{func(b *bridgeConn) { bridge(b) }} if cfg.immediateClose { - f := func(c net.Conn) { - log.Println("terminating connection immediately") - c.Close() + f := func(b *bridgeConn) { + log.Printf("terminating connection %s immediately", b.String()) + b.Close() } connFaults = append(connFaults, f) } @@ -175,31 +233,29 @@ func main() { connFaults = append(connFaults, blackhole) } if cfg.timeClose { - f := func(c net.Conn) { - go func() { - t := time.Duration(rand.Intn(5)+1) * time.Second - time.Sleep(t) - log.Printf("killing connection %v <-> %v after %v\n", - c.LocalAddr(), - c.RemoteAddr(), - t) - c.Close() - }() - bridge(c, fwdAddr) - } - connFaults = append(connFaults, f) + connFaults = append(connFaults, timeBridge) } if cfg.writeRemoteOnly { - f := func(c net.Conn) { writeRemoteOnly(c, fwdAddr) } - connFaults = append(connFaults, f) + connFaults = append(connFaults, writeRemoteOnly) } if cfg.readRemoteOnly { - f := func(c net.Conn) { readRemoteOnly(c, fwdAddr) } - connFaults = append(connFaults, f) + connFaults = append(connFaults, readRemoteOnly) } if cfg.randomBlackhole { - f := func(c net.Conn) { randomBlackhole(c, fwdAddr) } - connFaults = append(connFaults, f) + connFaults = append(connFaults, randomBlackhole) + } + if cfg.corruptSend { + connFaults = append(connFaults, corruptSend) + } + if cfg.corruptReceive { + connFaults = append(connFaults, corruptReceive) + } + + var disp dispatcher + if cfg.reorder { + disp = newDispatcherPool() + } else { + disp = newDispatcherImmediate() } for { @@ -213,7 +269,12 @@ func main() { if rand.Intn(100) > int(100.0*cfg.connFaultRate) { r = 0 } - go connFaults[r](conn) - } + bc, err := newBridgeConn(conn, disp) + if err != nil { + log.Printf("oops %v", err) + continue + } + go connFaults[r](bc) + } } diff --git a/tools/local-tester/bridge/dispatch.go b/tools/local-tester/bridge/dispatch.go new file mode 100644 index 000000000..b385cefe0 --- /dev/null +++ b/tools/local-tester/bridge/dispatch.go @@ -0,0 +1,140 @@ +// Copyright 2016 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "io" + "math/rand" + "sync" + "time" +) + +var ( + // dispatchPoolDelay is the time to wait before flushing all buffered packets + dispatchPoolDelay = 100 * time.Millisecond + // dispatchPacketBytes is how many bytes to send until choosing a new connection + dispatchPacketBytes = 32 +) + +type dispatcher interface { + // Copy works like io.Copy using buffers provided by fetchFunc + Copy(io.Writer, fetchFunc) error +} + +type fetchFunc func() ([]byte, error) + +type dispatcherPool struct { + // mu protects the dispatch packet queue 'q' + mu sync.Mutex + q []dispatchPacket +} + +type dispatchPacket struct { + buf []byte + out io.Writer +} + +func newDispatcherPool() dispatcher { + d := &dispatcherPool{} + go d.writeLoop() + return d +} + +func (d *dispatcherPool) writeLoop() { + for { + time.Sleep(dispatchPoolDelay) + d.flush() + } +} + +func (d *dispatcherPool) flush() { + d.mu.Lock() + pkts := d.q + d.q = nil + d.mu.Unlock() + if len(pkts) == 0 { + return + } + + // sort by sockets; preserve the packet ordering within a socket + pktmap := make(map[io.Writer][]dispatchPacket) + outs := []io.Writer{} + for _, pkt := range pkts { + opkts, ok := pktmap[pkt.out] + if !ok { + outs = append(outs, pkt.out) + } + pktmap[pkt.out] = append(opkts, pkt) + } + + // send all packets in pkts + for len(outs) != 0 { + // randomize writer on every write + r := rand.Intn(len(outs)) + rpkts := pktmap[outs[r]] + rpkts[0].out.Write(rpkts[0].buf) + // dequeue packet + rpkts = rpkts[1:] + if len(rpkts) == 0 { + delete(pktmap, outs[r]) + outs = append(outs[:r], outs[r+1:]...) + } else { + pktmap[outs[r]] = rpkts + } + } +} + +func (d *dispatcherPool) Copy(w io.Writer, f fetchFunc) error { + for { + b, err := f() + if err != nil { + return err + } + + pkts := []dispatchPacket{} + for len(b) > 0 { + pkt := b + if len(b) > dispatchPacketBytes { + pkt = pkt[:dispatchPacketBytes] + b = b[dispatchPacketBytes:] + } else { + b = nil + } + pkts = append(pkts, dispatchPacket{pkt, w}) + } + + d.mu.Lock() + d.q = append(d.q, pkts...) + d.mu.Unlock() + } +} + +type dispatcherImmediate struct{} + +func newDispatcherImmediate() dispatcher { + return &dispatcherImmediate{} +} + +func (d *dispatcherImmediate) Copy(w io.Writer, f fetchFunc) error { + for { + b, err := f() + if err != nil { + return err + } + if _, err := w.Write(b); err != nil { + return err + } + } +}