diff --git a/README.md b/README.md index a701516..0f25976 100644 --- a/README.md +++ b/README.md @@ -39,19 +39,19 @@ requests should be sent. Baton will wait for all the responses to be received be ##### Requests file -When specifying a file to load requests from (`-z filename`), the file should have the following format: +When specifying a file to load requests from (`-z filename`), the file should be of CSV format ([RFC-4180](https://tools.ietf.org/html/rfc4180)) ``` -Method>Uri>Body(Optional) +Method,URL,Body,Headers ... ``` +The headers column can contain multiple headers separated by newline, however make sure the whole column is quoted as per the RFC. + For example: ``` -POST>http://localhost:8080>Data 1 -POST>http://localhost:8080>Data 2 -POST>http://localhost:8080>Data 3 -POST>http://localhost:8080>Data 4 -GET>http://localhost:8080> +POST,http://localhost:8888,body,"Accept: application/xml +Content-type: Secret" +GET,http://localhost:8888,, ``` ##### Example Output: diff --git a/baton.go b/baton.go index 0d42226..4906a70 100644 --- a/baton.go +++ b/baton.go @@ -1,14 +1,12 @@ package main import ( - "bufio" "crypto/tls" "flag" "github.com/valyala/fasthttp" "io/ioutil" "log" "os" - "strings" "time" ) @@ -21,7 +19,7 @@ var ( method = flag.String("m", "GET", "HTTP Method (GET,POST,PUT,DELETE)") numberOfRequests = flag.Int("r", 1, "Number of requests (use instead of -t)") requestsFromFile = flag.String("z", "", "Read requests from a file") - supressOutput = flag.Bool("o", false, "Supress output, no results will be printed to stdout") + suppressOutput = flag.Bool("o", false, "Suppress output, no results will be printed to stdout") url = flag.String("u", "", "URL to run against") wait = flag.Int("w", 0, "Number of seconds to wait before running test") ) @@ -36,7 +34,7 @@ type Configuration struct { method string numberOfRequests int requestsFromFile string - supressOutput bool + suppressOutput bool url string wait int } @@ -48,9 +46,14 @@ type Baton struct { } type preloadedRequest struct { - method string - url string - body string + // The HTTP method used to send the request + method string + // The URL to send the request at + url string + // The body of the request (if appropriate method is selected) + body string + // Array of two-element key/value pairs of header and value + headers [][]string } func main() { @@ -65,7 +68,7 @@ func main() { *method, *numberOfRequests, *requestsFromFile, - *supressOutput, + *suppressOutput, *url, *wait, } @@ -76,30 +79,11 @@ func main() { baton.result.printResults() } -func preloadRequestsFromFile(filename string) ([]preloadedRequest, error) { - file, err := os.Open(filename) - if err != nil { - return nil, err - } - defer file.Close() - var requests []preloadedRequest - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := scanner.Text() - parts := strings.SplitN(line, ">", 3) - method := parts[0] - url := parts[1] - body := parts[2] - requests = append(requests, preloadedRequest{method, url, body}) - } - return requests, scanner.Err() -} - func (baton *Baton) run() { logWriter := &logWriter{true} - if baton.configuration.supressOutput { + if baton.configuration.suppressOutput { logWriter.Disable() } @@ -183,7 +167,8 @@ func (baton *Baton) run() { if preloadedRequestsMode { go worker.sendRequests(preloadedRequests) } else { - go worker.sendRequest(preloadedRequest{baton.configuration.method, baton.configuration.url, baton.configuration.body}) + request := preloadedRequest{baton.configuration.method, baton.configuration.url, baton.configuration.body, [][]string{}} + go worker.sendRequest(request) } } diff --git a/baton_test.go b/baton_test.go index cec71f2..effb0e5 100644 --- a/baton_test.go +++ b/baton_test.go @@ -12,11 +12,12 @@ import ( ) type HTTPTestHandler struct { - noRequestsReceived uint32 - lastBodyReceived string - lastMethodReceived string - lastURIReceived string - lastTimestamp int64 + noRequestsReceived uint32 + lastBodyReceived string + lastMethodReceived string + lastURIReceived string + lastHeadersReceived fasthttp.RequestHeader + lastTimestamp int64 } func (h *HTTPTestHandler) HandleRequest(ctx *fasthttp.RequestCtx) { @@ -25,6 +26,7 @@ func (h *HTTPTestHandler) HandleRequest(ctx *fasthttp.RequestCtx) { h.lastBodyReceived = hex.EncodeToString(ctx.Request.Body()) h.lastMethodReceived = string(ctx.Request.Header.Method()) h.lastURIReceived = ctx.Request.URI().String() + h.lastHeadersReceived = ctx.Request.Header } func (h *HTTPTestHandler) reset() { @@ -41,7 +43,7 @@ var port = "8888" func startServer() *HTTPTestHandler { if !serverRunning { - internalHandlerRef = &HTTPTestHandler{0, "", "", "", 0} + internalHandlerRef = &HTTPTestHandler{0, "", "", "", fasthttp.RequestHeader{}, 0} serverRunning = true go func() { err := fasthttp.ListenAndServe(":"+port, internalHandlerRef.HandleRequest) @@ -171,7 +173,7 @@ func TestLoadPostFromTextFile(t *testing.T) { func TestPostRequestLoadedFromFile(t *testing.T) { uri := "http://localhost:" + port method := "POST" - fileContents := method + ">" + uri + ">" + "Data" + fileContents := method + "," + uri + "," + "Data" fileInBytes := []byte(fileContents) fileDir := "test-resources/requests-from-file.txt" @@ -198,6 +200,36 @@ func TestPostRequestLoadedFromFile(t *testing.T) { } } +func TestThatHeadersAreSetWhenSendingFromFile(t *testing.T) { + uri := "http://localhost:" + port + method := "GET" + fileContents := method + "," + uri + "," + "" + "," +"\"Content-Type: Hello\r\nSecret: World\"" + fileInBytes := []byte(fileContents) + + fileDir := "test-resources/requests-from-file.txt" + if ioutil.WriteFile(fileDir, fileInBytes, 0644) != nil { + t.Errorf("Failed to write a required test case file. Check the directory permissions.") + } + defer os.Remove(fileDir) + + config := defaultConfig() + config.requestsFromFile = fileDir + config.numberOfRequests = 1 + testHandler := setupAndListen(config) + + headerActual := hex.EncodeToString(testHandler.lastHeadersReceived.Peek("Content-Type")) + headerExpected := hex.EncodeToString([]byte("Hello")) + if headerExpected != headerActual { + t.Errorf("Header not found or improperly set, Expected %s, got %s", headerExpected, headerActual) + } + + headerActual2 := hex.EncodeToString(testHandler.lastHeadersReceived.Peek("Secret")) + headerExpected2 := hex.EncodeToString([]byte("World")) + if headerExpected != headerActual { + t.Errorf("Header not found or improperly set, Expected %s, got %s", headerExpected2, headerActual2) + } +} + func TestThatTimeOptionRunsForCorrectAmountOfTime(t *testing.T) { duration := 10 testHandler := startServer() @@ -218,3 +250,4 @@ func TestThatTimeOptionRunsForCorrectAmountOfTime(t *testing.T) { t.Errorf("Requests sent for longer/shorter than expected. Expected %d, got %d)", duration, diff) } } + diff --git a/csv_parsing.go b/csv_parsing.go new file mode 100644 index 0000000..a18c841 --- /dev/null +++ b/csv_parsing.go @@ -0,0 +1,73 @@ +package main + +import ( + "os" + "encoding/csv" + "bufio" + "io" + "strings" + "runtime" + "errors" +) + +func extractHeaders(rawHeaders string) [][]string { + var headers [][]string + if rawHeaders != "" { + header := strings.Split(rawHeaders, "\n") + for i := 0; i < len(header); i++ { + headerParts := strings.Split(header[i], ":") + if len(headerParts) == 2 { + headers = append(headers, []string{headerParts[0], headerParts[1]}) + } + } + } + return headers +} + +func preloadRequestsFromFile(filename string) ([]preloadedRequest, error) { + file, err := os.Open(filename) + + if err != nil { + return nil, err + } + + reader := csv.NewReader(bufio.NewReader(file)) + var requests []preloadedRequest + + for { + record, err := reader.Read() + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + + var method = "" + var url = "" + var body = "" + var headers [][]string + noFields := len(record) + + if noFields < 2 { + return nil, errors.New("invalid number of fields") + } + + if noFields >= 2{ + method = record[0] + url = record[1] + } + + if noFields >= 3 { + body = record[2] + } + + if noFields >= 4 { + headers = extractHeaders(record[3]) + } + + requests = append(requests, preloadedRequest{method, url, body, headers}) + } + + return requests, nil +} diff --git a/worker.go b/worker.go index 98d0e63..8207412 100644 --- a/worker.go +++ b/worker.go @@ -59,6 +59,9 @@ func buildRequest(requests []preloadedRequest, totalPremadeRequests int) (*fasth req.SetRequestURI(currentReq.url) req.Header.SetMethod(currentReq.method) req.SetBodyString(currentReq.body) + for i := 0; i < len(currentReq.headers); i++ { + req.Header.Add(currentReq.headers[i][0], currentReq.headers[i][1]) + } return req, resp }