diff --git a/raft_server.go b/raft_server.go index b3b5d3277..c8b86021c 100644 --- a/raft_server.go +++ b/raft_server.go @@ -16,13 +16,13 @@ import ( type raftServer struct { *raft.Server - version string - joinIndex uint64 - name string - url string + version string + joinIndex uint64 + name string + url string listenHost string - tlsConf *TLSConfig - tlsInfo *TLSInfo + tlsConf *TLSConfig + tlsInfo *TLSInfo } var r *raftServer @@ -30,7 +30,7 @@ var r *raftServer func newRaftServer(name string, url string, listenHost string, tlsConf *TLSConfig, tlsInfo *TLSInfo) *raftServer { // Create transporter for raft - raftTransporter := newTransporter(tlsConf.Scheme, tlsConf.Client) + raftTransporter := newTransporter(tlsConf.Scheme, tlsConf.Client, ElectionTimeout) // Create raft server server, err := raft.NewServer(name, dirPath, raftTransporter, etcdStore, nil) @@ -38,13 +38,13 @@ func newRaftServer(name string, url string, listenHost string, tlsConf *TLSConfi check(err) return &raftServer{ - Server: server, - version: raftVersion, - name: name, - url: url, + Server: server, + version: raftVersion, + name: name, + url: url, listenHost: listenHost, - tlsConf: tlsConf, - tlsInfo: tlsInfo, + tlsConf: tlsConf, + tlsInfo: tlsInfo, } } @@ -169,7 +169,7 @@ func (r *raftServer) startTransport(scheme string, tlsConf tls.Config) { // getVersion fetches the raft version of a peer. This works for now but we // will need to do something more sophisticated later when we allow mixed // version clusters. -func getVersion(t transporter, versionURL url.URL) (string, error) { +func getVersion(t *transporter, versionURL url.URL) (string, error) { resp, err := t.Get(versionURL.String()) if err != nil { @@ -198,6 +198,7 @@ func joinCluster(cluster []string) bool { if _, ok := err.(etcdErr.Error); ok { fatal(err) } + debugf("cannot join to cluster via machine %s %s", machine, err) } } @@ -209,7 +210,7 @@ func joinByMachine(s *raft.Server, machine string, scheme string) error { var b bytes.Buffer // t must be ok - t, _ := r.Transporter().(transporter) + t, _ := r.Transporter().(*transporter) // Our version must match the leaders version versionURL := url.URL{Host: machine, Scheme: scheme, Path: "/version"} diff --git a/transporter.go b/transporter.go index c49479bc8..b4564742c 100644 --- a/transporter.go +++ b/transporter.go @@ -9,17 +9,25 @@ import ( "io" "net" "net/http" + "time" ) // Transporter layer for communication between raft nodes type transporter struct { - client *http.Client + client *http.Client + timeout time.Duration +} + +// response struct +type transporterResponse struct { + resp *http.Response + err error } // Create transporter using by raft server // Create http or https transporter based on // whether the user give the server cert and key -func newTransporter(scheme string, tlsConf tls.Config) transporter { +func newTransporter(scheme string, tlsConf tls.Config, timeout time.Duration) *transporter { t := transporter{} tr := &http.Transport{ @@ -32,8 +40,9 @@ func newTransporter(scheme string, tlsConf tls.Config) transporter { } t.client = &http.Client{Transport: tr} + t.timeout = timeout - return t + return &t } // Dial with timeout @@ -42,7 +51,7 @@ func dialTimeout(network, addr string) (net.Conn, error) { } // Sends AppendEntries RPCs to a peer when the server is the leader. -func (t transporter) SendAppendEntriesRequest(server *raft.Server, peer *raft.Peer, req *raft.AppendEntriesRequest) *raft.AppendEntriesResponse { +func (t *transporter) SendAppendEntriesRequest(server *raft.Server, peer *raft.Peer, req *raft.AppendEntriesRequest) *raft.AppendEntriesResponse { var aersp *raft.AppendEntriesResponse var b bytes.Buffer json.NewEncoder(&b).Encode(req) @@ -69,7 +78,7 @@ func (t transporter) SendAppendEntriesRequest(server *raft.Server, peer *raft.Pe } // Sends RequestVote RPCs to a peer when the server is the candidate. -func (t transporter) SendVoteRequest(server *raft.Server, peer *raft.Peer, req *raft.RequestVoteRequest) *raft.RequestVoteResponse { +func (t *transporter) SendVoteRequest(server *raft.Server, peer *raft.Peer, req *raft.RequestVoteRequest) *raft.RequestVoteResponse { var rvrsp *raft.RequestVoteResponse var b bytes.Buffer json.NewEncoder(&b).Encode(req) @@ -95,7 +104,7 @@ func (t transporter) SendVoteRequest(server *raft.Server, peer *raft.Peer, req * } // Sends SnapshotRequest RPCs to a peer when the server is the candidate. -func (t transporter) SendSnapshotRequest(server *raft.Server, peer *raft.Peer, req *raft.SnapshotRequest) *raft.SnapshotResponse { +func (t *transporter) SendSnapshotRequest(server *raft.Server, peer *raft.Peer, req *raft.SnapshotRequest) *raft.SnapshotResponse { var aersp *raft.SnapshotResponse var b bytes.Buffer json.NewEncoder(&b).Encode(req) @@ -123,7 +132,7 @@ func (t transporter) SendSnapshotRequest(server *raft.Server, peer *raft.Peer, r } // Sends SnapshotRecoveryRequest RPCs to a peer when the server is the candidate. -func (t transporter) SendSnapshotRecoveryRequest(server *raft.Server, peer *raft.Peer, req *raft.SnapshotRecoveryRequest) *raft.SnapshotRecoveryResponse { +func (t *transporter) SendSnapshotRecoveryRequest(server *raft.Server, peer *raft.Peer, req *raft.SnapshotRecoveryRequest) *raft.SnapshotRecoveryResponse { var aersp *raft.SnapshotRecoveryResponse var b bytes.Buffer json.NewEncoder(&b).Encode(req) @@ -150,11 +159,46 @@ func (t transporter) SendSnapshotRecoveryRequest(server *raft.Server, peer *raft } // Send server side POST request -func (t transporter) Post(path string, body io.Reader) (*http.Response, error) { - return t.client.Post(path, "application/json", body) +func (t *transporter) Post(path string, body io.Reader) (*http.Response, error) { + + c := make(chan *transporterResponse, 1) + + go func() { + tr := new(transporterResponse) + tr.resp, tr.err = t.client.Post(path, "application/json", body) + c <- tr + }() + + return t.waitResponse(c) + } // Send server side GET request -func (t transporter) Get(path string) (*http.Response, error) { - return t.client.Get(path) +func (t *transporter) Get(path string) (*http.Response, error) { + + c := make(chan *transporterResponse, 1) + + go func() { + tr := new(transporterResponse) + tr.resp, tr.err = t.client.Get(path) + c <- tr + }() + + return t.waitResponse(c) +} + +func (t *transporter) waitResponse(responseChan chan *transporterResponse) (*http.Response, error) { + + timeoutChan := time.After(t.timeout) + + select { + case <-timeoutChan: + return nil, fmt.Errorf("Wait Response Timeout: %v", t.timeout) + + case r := <-responseChan: + return r.resp, r.err + } + + // for complier + return nil, nil } diff --git a/transporter_test.go b/transporter_test.go new file mode 100644 index 000000000..e440a094f --- /dev/null +++ b/transporter_test.go @@ -0,0 +1,36 @@ +package main + +import ( + "crypto/tls" + "testing" + "time" +) + +func TestTransporterTimeout(t *testing.T) { + + conf := tls.Config{} + + ts := newTransporter("http", conf, time.Second) + + ts.Get("http://google.com") + _, err := ts.Get("http://google.com:9999") // it doesn't exisit + if err == nil || err.Error() != "Wait Response Timeout: 1s" { + t.Fatal("timeout error: ", err.Error()) + } + + _, err = ts.Post("http://google.com:9999", nil) // it doesn't exisit + if err == nil || err.Error() != "Wait Response Timeout: 1s" { + t.Fatal("timeout error: ", err.Error()) + } + + _, err = ts.Get("http://www.google.com") + if err != nil { + t.Fatal("get error") + } + + _, err = ts.Post("http://www.google.com", nil) + if err != nil { + t.Fatal("post error") + } + +}