diff --git a/raft_server.go b/raft_server.go index b3b5d3277..fa3cb3a39 100644 --- a/raft_server.go +++ b/raft_server.go @@ -16,13 +16,13 @@ import ( type raftServer struct { *raft.Server - version string - joinIndex uint64 - name string - url string + version string + joinIndex uint64 + name string + url string listenHost string - tlsConf *TLSConfig - tlsInfo *TLSInfo + tlsConf *TLSConfig + tlsInfo *TLSInfo } var r *raftServer @@ -30,7 +30,7 @@ var r *raftServer func newRaftServer(name string, url string, listenHost string, tlsConf *TLSConfig, tlsInfo *TLSInfo) *raftServer { // Create transporter for raft - raftTransporter := newTransporter(tlsConf.Scheme, tlsConf.Client) + raftTransporter := newTransporter(tlsConf.Scheme, tlsConf.Client, raft.DefaultHeartbeatTimeout) // Create raft server server, err := raft.NewServer(name, dirPath, raftTransporter, etcdStore, nil) @@ -38,13 +38,13 @@ func newRaftServer(name string, url string, listenHost string, tlsConf *TLSConfi check(err) return &raftServer{ - Server: server, - version: raftVersion, - name: name, - url: url, + Server: server, + version: raftVersion, + name: name, + url: url, listenHost: listenHost, - tlsConf: tlsConf, - tlsInfo: tlsInfo, + tlsConf: tlsConf, + tlsInfo: tlsInfo, } } diff --git a/transporter.go b/transporter.go index c49479bc8..66a179791 100644 --- a/transporter.go +++ b/transporter.go @@ -9,17 +9,19 @@ import ( "io" "net" "net/http" + "time" ) // Transporter layer for communication between raft nodes type transporter struct { - client *http.Client + client *http.Client + timeout time.Duration } // Create transporter using by raft server // Create http or https transporter based on // whether the user give the server cert and key -func newTransporter(scheme string, tlsConf tls.Config) transporter { +func newTransporter(scheme string, tlsConf tls.Config, timeout time.Duration) transporter { t := transporter{} tr := &http.Transport{ @@ -32,6 +34,7 @@ func newTransporter(scheme string, tlsConf tls.Config) transporter { } t.client = &http.Client{Transport: tr} + t.timeout = timeout return t } @@ -151,10 +154,58 @@ func (t transporter) SendSnapshotRecoveryRequest(server *raft.Server, peer *raft // Send server side POST request func (t transporter) Post(path string, body io.Reader) (*http.Response, error) { - return t.client.Post(path, "application/json", body) + + postChan := make(chan interface{}, 1) + + go func() { + resp, err := t.client.Post(path, "application/json", body) + if err == nil { + postChan <- resp + } else { + postChan <- err + } + }() + + return t.waitResponse(postChan) + } // Send server side GET request func (t transporter) Get(path string) (*http.Response, error) { - return t.client.Get(path) + + getChan := make(chan interface{}, 1) + + go func() { + resp, err := t.client.Get(path) + if err == nil { + getChan <- resp + } else { + getChan <- err + } + }() + + return t.waitResponse(getChan) +} + +func (t transporter) waitResponse(responseChan chan interface{}) (*http.Response, error) { + + timeoutChan := time.After(t.timeout) + + select { + case <-timeoutChan: + return nil, fmt.Errorf("Wait Response Timeout: %v", t.timeout) + + case r := <-responseChan: + switch r := r.(type) { + case error: + return nil, r + + case *http.Response: + return r, nil + + } + } + + // for complier + return nil, nil } diff --git a/transporter_test.go b/transporter_test.go new file mode 100644 index 000000000..88fefdcf7 --- /dev/null +++ b/transporter_test.go @@ -0,0 +1,35 @@ +package main + +import ( + "crypto/tls" + "testing" + "time" +) + +func TestTransporterTimeout(t *testing.T) { + + conf := tls.Config{} + + ts := newTransporter("http", conf, time.Second) + + _, err := ts.Get("http://127.0.0.2:7000") + if err == nil || err.Error() != "Wait Response Timeout: 1s" { + t.Fatal("timeout error: ", err.Error()) + } + + _, err = ts.Post("http://127.0.0.2:7000", nil) + if err == nil || err.Error() != "Wait Response Timeout: 1s" { + t.Fatal("timeout error: ", err.Error()) + } + + _, err = ts.Get("http://www.google.com") + if err != nil { + t.Fatal("get error") + } + + _, err = ts.Post("http://www.google.com", nil) + if err != nil { + t.Fatal("post error") + } + +}