From 6f32b2d57642b81c1f8eef63899536a921a2d875 Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Fri, 27 Sep 2013 21:24:33 -0700 Subject: [PATCH] fix timeout --- etcd.go | 7 +--- raft_server.go | 14 +++++-- transporter.go | 97 +++++++++++++++++++++++---------------------- transporter_test.go | 47 +++++++++++++++++----- 4 files changed, 96 insertions(+), 69 deletions(-) diff --git a/etcd.go b/etcd.go index 7149a1af2..b0a2ccb5c 100644 --- a/etcd.go +++ b/etcd.go @@ -90,12 +90,7 @@ func init() { const ( ElectionTimeout = 200 * time.Millisecond HeartbeatTimeout = 50 * time.Millisecond - - // Timeout for internal raft http connection - // The original timeout for http is 45 seconds - // which is too long for our usage. - HTTPTimeout = 10 * time.Second - RetryInterval = 10 + RetryInterval = 10 ) //------------------------------------------------------------------------------ diff --git a/raft_server.go b/raft_server.go index fc204546a..a777515c4 100644 --- a/raft_server.go +++ b/raft_server.go @@ -33,7 +33,7 @@ var r *raftServer func newRaftServer(name string, url string, listenHost string, tlsConf *TLSConfig, tlsInfo *TLSInfo) *raftServer { // Create transporter for raft - raftTransporter := newTransporter(tlsConf.Scheme, tlsConf.Client, ElectionTimeout) + raftTransporter := newTransporter(tlsConf.Scheme, tlsConf.Client) // Create raft server server, err := raft.NewServer(name, dirPath, raftTransporter, etcdStore, nil, "") @@ -185,13 +185,16 @@ func (r *raftServer) startTransport(scheme string, tlsConf tls.Config) { // will need to do something more sophisticated later when we allow mixed // version clusters. func getVersion(t *transporter, versionURL url.URL) (string, error) { - resp, err := t.Get(versionURL.String()) + resp, req, err := t.Get(versionURL.String()) if err != nil { return "", err } defer resp.Body.Close() + + t.CancelWhenTimeout(req) + body, err := ioutil.ReadAll(resp.Body) return string(body), nil @@ -246,7 +249,7 @@ func joinByMachine(s *raft.Server, machine string, scheme string) error { debugf("Send Join Request to %s", joinURL.String()) - resp, err := t.Post(joinURL.String(), &b) + resp, req, err := t.Post(joinURL.String(), &b) for { if err != nil { @@ -254,6 +257,9 @@ func joinByMachine(s *raft.Server, machine string, scheme string) error { } if resp != nil { defer resp.Body.Close() + + t.CancelWhenTimeout(req) + if resp.StatusCode == http.StatusOK { b, _ := ioutil.ReadAll(resp.Body) r.joinIndex, _ = binary.Uvarint(b) @@ -266,7 +272,7 @@ func joinByMachine(s *raft.Server, machine string, scheme string) error { json.NewEncoder(&b).Encode(newJoinCommand()) - resp, err = t.Post(address, &b) + resp, req, err = t.Post(address, &b) } else if resp.StatusCode == http.StatusBadRequest { debug("Reach max number machines in the cluster") diff --git a/transporter.go b/transporter.go index c17c9d35f..6a4302013 100644 --- a/transporter.go +++ b/transporter.go @@ -13,26 +13,33 @@ import ( "github.com/coreos/go-raft" ) +// Timeout for setup internal raft http connection +// This should not exceed 3 * RTT +var dailTimeout = 3 * HeartbeatTimeout + +// Timeout for setup internal raft http connection + receive response header +// This should not exceed 3 * RTT + RTT +var responseHeaderTimeout = 4 * HeartbeatTimeout + +// Timeout for actually read the response body from the server +// This hould not exceed election timeout +var tranTimeout = ElectionTimeout + // Transporter layer for communication between raft nodes type transporter struct { - client *http.Client - timeout time.Duration -} - -// response struct -type transporterResponse struct { - resp *http.Response - err error + client *http.Client + transport *http.Transport } // Create transporter using by raft server // Create http or https transporter based on // whether the user give the server cert and key -func newTransporter(scheme string, tlsConf tls.Config, timeout time.Duration) *transporter { +func newTransporter(scheme string, tlsConf tls.Config) *transporter { t := transporter{} tr := &http.Transport{ - Dial: dialTimeout, + Dial: dialWithTimeout, + ResponseHeaderTimeout: responseHeaderTimeout, } if scheme == "https" { @@ -41,14 +48,14 @@ func newTransporter(scheme string, tlsConf tls.Config, timeout time.Duration) *t } t.client = &http.Client{Transport: tr} - t.timeout = timeout + t.transport = tr return &t } // Dial with timeout -func dialTimeout(network, addr string) (net.Conn, error) { - return net.DialTimeout(network, addr, HTTPTimeout) +func dialWithTimeout(network, addr string) (net.Conn, error) { + return net.DialTimeout(network, addr, dailTimeout) } // Sends AppendEntries RPCs to a peer when the server is the leader. @@ -76,7 +83,7 @@ func (t *transporter) SendAppendEntriesRequest(server *raft.Server, peer *raft.P start := time.Now() - resp, err := t.Post(fmt.Sprintf("%s/log/append", u), &b) + resp, httpRequest, err := t.Post(fmt.Sprintf("%s/log/append", u), &b) end := time.Now() @@ -93,6 +100,9 @@ func (t *transporter) SendAppendEntriesRequest(server *raft.Server, peer *raft.P if resp != nil { defer resp.Body.Close() + + t.CancelWhenTimeout(httpRequest) + aersp = &raft.AppendEntriesResponse{} if err := json.NewDecoder(resp.Body).Decode(&aersp); err == nil || err == io.EOF { return aersp @@ -112,7 +122,7 @@ func (t *transporter) SendVoteRequest(server *raft.Server, peer *raft.Peer, req u, _ := nameToRaftURL(peer.Name) debugf("Send Vote to %s", u) - resp, err := t.Post(fmt.Sprintf("%s/vote", u), &b) + resp, httpRequest, err := t.Post(fmt.Sprintf("%s/vote", u), &b) if err != nil { debugf("Cannot send VoteRequest to %s : %s", u, err) @@ -120,6 +130,9 @@ func (t *transporter) SendVoteRequest(server *raft.Server, peer *raft.Peer, req if resp != nil { defer resp.Body.Close() + + t.CancelWhenTimeout(httpRequest) + rvrsp := &raft.RequestVoteResponse{} if err := json.NewDecoder(resp.Body).Decode(&rvrsp); err == nil || err == io.EOF { return rvrsp @@ -139,7 +152,7 @@ func (t *transporter) SendSnapshotRequest(server *raft.Server, peer *raft.Peer, debugf("Send Snapshot to %s [Last Term: %d, LastIndex %d]", u, req.LastTerm, req.LastIndex) - resp, err := t.Post(fmt.Sprintf("%s/snapshot", u), &b) + resp, httpRequest, err := t.Post(fmt.Sprintf("%s/snapshot", u), &b) if err != nil { debugf("Cannot send SendSnapshotRequest to %s : %s", u, err) @@ -147,6 +160,9 @@ func (t *transporter) SendSnapshotRequest(server *raft.Server, peer *raft.Peer, if resp != nil { defer resp.Body.Close() + + t.CancelWhenTimeout(httpRequest) + aersp = &raft.SnapshotResponse{} if err = json.NewDecoder(resp.Body).Decode(&aersp); err == nil || err == io.EOF { @@ -167,7 +183,7 @@ func (t *transporter) SendSnapshotRecoveryRequest(server *raft.Server, peer *raf debugf("Send SnapshotRecovery to %s [Last Term: %d, LastIndex %d]", u, req.LastTerm, req.LastIndex) - resp, err := t.Post(fmt.Sprintf("%s/snapshotRecovery", u), &b) + resp, _, err := t.Post(fmt.Sprintf("%s/snapshotRecovery", u), &b) if err != nil { debugf("Cannot send SendSnapshotRecoveryRequest to %s : %s", u, err) @@ -176,6 +192,7 @@ func (t *transporter) SendSnapshotRecoveryRequest(server *raft.Server, peer *raf if resp != nil { defer resp.Body.Close() aersp = &raft.SnapshotRecoveryResponse{} + if err = json.NewDecoder(resp.Body).Decode(&aersp); err == nil || err == io.EOF { return aersp } @@ -185,46 +202,30 @@ func (t *transporter) SendSnapshotRecoveryRequest(server *raft.Server, peer *raf } // Send server side POST request -func (t *transporter) Post(path string, body io.Reader) (*http.Response, error) { +func (t *transporter) Post(urlStr string, body io.Reader) (*http.Response, *http.Request, error) { - c := make(chan *transporterResponse, 1) + req, _ := http.NewRequest("POST", urlStr, body) - go func() { - tr := new(transporterResponse) - tr.resp, tr.err = t.client.Post(path, "application/json", body) - c <- tr - }() + resp, err := t.client.Do(req) - return t.waitResponse(c) + return resp, req, err } // Send server side GET request -func (t *transporter) Get(path string) (*http.Response, error) { +func (t *transporter) Get(urlStr string) (*http.Response, *http.Request, error) { - c := make(chan *transporterResponse, 1) + req, _ := http.NewRequest("GET", urlStr, nil) + resp, err := t.client.Do(req) + + return resp, req, err +} + +// Cancel the on fly HTTP transaction when timeout happens +func (t *transporter) CancelWhenTimeout(req *http.Request) { go func() { - tr := new(transporterResponse) - tr.resp, tr.err = t.client.Get(path) - c <- tr + time.Sleep(ElectionTimeout) + t.transport.CancelRequest(req) }() - - return t.waitResponse(c) -} - -func (t *transporter) waitResponse(responseChan chan *transporterResponse) (*http.Response, error) { - - timeoutChan := time.After(t.timeout * 10) - - select { - case <-timeoutChan: - return nil, fmt.Errorf("Wait Response Timeout: %v", t.timeout) - - case r := <-responseChan: - return r.resp, r.err - } - - // for complier - return nil, nil } diff --git a/transporter_test.go b/transporter_test.go index e440a094f..8c71325c6 100644 --- a/transporter_test.go +++ b/transporter_test.go @@ -2,33 +2,58 @@ package main import ( "crypto/tls" + "fmt" + "io/ioutil" + "net/http" "testing" "time" ) func TestTransporterTimeout(t *testing.T) { + http.HandleFunc("/timeout", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "timeout") + w.(http.Flusher).Flush() // send headers and some body + time.Sleep(time.Second * 100) + }) + + go http.ListenAndServe(":8080", nil) + conf := tls.Config{} - ts := newTransporter("http", conf, time.Second) + ts := newTransporter("http", conf) ts.Get("http://google.com") - _, err := ts.Get("http://google.com:9999") // it doesn't exisit - if err == nil || err.Error() != "Wait Response Timeout: 1s" { - t.Fatal("timeout error: ", err.Error()) + _, _, err := ts.Get("http://google.com:9999") + if err == nil { + t.Fatal("timeout error") } - _, err = ts.Post("http://google.com:9999", nil) // it doesn't exisit - if err == nil || err.Error() != "Wait Response Timeout: 1s" { - t.Fatal("timeout error: ", err.Error()) - } + res, req, err := ts.Get("http://localhost:8080/timeout") - _, err = ts.Get("http://www.google.com") if err != nil { - t.Fatal("get error") + t.Fatal("should not timeout") } - _, err = ts.Post("http://www.google.com", nil) + ts.CancelWhenTimeout(req) + + body, err := ioutil.ReadAll(res.Body) + if err == nil { + fmt.Println(string(body)) + t.Fatal("expected an error reading the body") + } + + _, _, err = ts.Post("http://google.com:9999", nil) + if err == nil { + t.Fatal("timeout error") + } + + _, _, err = ts.Get("http://www.google.com") + if err != nil { + t.Fatal("get error: ", err.Error()) + } + + _, _, err = ts.Post("http://www.google.com", nil) if err != nil { t.Fatal("post error") }