diff --git a/baton.go b/baton.go index 5a16e68..e7acd95 100644 --- a/baton.go +++ b/baton.go @@ -2,6 +2,7 @@ package main import ( "crypto/tls" + "errors" "flag" "github.com/valyala/fasthttp" "io/ioutil" @@ -23,34 +24,30 @@ var ( wait = flag.Int("w", 0, "Number of seconds to wait before running test") ) -// Configuration represents the Baton configuration -type Configuration struct { - body string - concurrency int - dataFilePath string - duration int - ignoreTLS bool - method string - numberOfRequests int - requestsFromFile string - suppressOutput bool - url string - wait int -} - // Baton implements the load tester type Baton struct { configuration Configuration result Result } -type preloadedRequest struct { +type preLoadedRequest struct { method string // The HTTP method used to send the request url string // The URL to send the request at body string // The body of the request (if appropriate method is selected) headers [][]string // Array of two-element key/value pairs of header and value } +type runConfiguration struct { + preLoadedRequestsMode bool + timedMode bool + preLoadedRequests []preLoadedRequest + client *fasthttp.Client + requests chan bool + results chan HTTPResult + done chan bool + body string +} + func main() { flag.Parse() @@ -76,93 +73,45 @@ func main() { func (baton *Baton) run() { - logWriter := &logWriter{true} + configureLogging(baton.configuration.suppressOutput) - if baton.configuration.suppressOutput { - logWriter.Disable() + err := baton.configuration.validate() + if err != nil { + log.Fatalf("Invalid configuration: %v", err) } - log.SetFlags(0) - log.SetOutput(logWriter) - - preloadedRequestsMode := false - timedMode := false - var preloadedRequests []preloadedRequest - var err error - - if baton.configuration.requestsFromFile != "" { - preloadedRequests, err = preloadRequestsFromFile(baton.configuration.requestsFromFile) - preloadedRequestsMode = true - if err != nil { - log.Fatal("Failed to parse requests from file: " + baton.configuration.requestsFromFile) - } - } - - if baton.configuration.duration != 0 { - timedMode = true - } - - if baton.configuration.concurrency < 1 || baton.configuration.numberOfRequests == 0 { - log.Fatal("Invalid concurrency level or number of requests") - } - - client := &fasthttp.Client{} - if baton.configuration.ignoreTLS { - tlsConfig := &tls.Config{InsecureSkipVerify: true} - client = &fasthttp.Client{TLSConfig: tlsConfig} - } - - if baton.configuration.dataFilePath != "" { - data, err := ioutil.ReadFile(baton.configuration.dataFilePath) - if err != nil { - log.Fatal(err.Error()) - } - baton.configuration.body = string(data) - } - - if preloadedRequestsMode { - log.Printf("Configuring to send requests from file. (Read %d requests)\n", len(preloadedRequests)) - } else { - log.Printf("Configuring to send %s requests to: %s\n", baton.configuration.method, baton.configuration.url) + preparedRunConfiguration, err := prepareRun(baton.configuration) + if err != nil { + log.Fatalf("Error during run preparation: %v", err) } if baton.configuration.wait > 0 { time.Sleep(time.Duration(baton.configuration.wait) * time.Second) } - requests := make(chan bool, baton.configuration.numberOfRequests) - results := make(chan HTTPResult, baton.configuration.concurrency) - done := make(chan bool, baton.configuration.concurrency) - - log.Println("Generating the requests...") - for r := 1; r <= baton.configuration.numberOfRequests; r++ { - requests <- true - } - close(requests) - log.Println("Finished generating the requests") log.Println("Sending the requests to the server...") // Start the timer and kick off the workers start := time.Now() for w := 1; w <= baton.configuration.concurrency; w++ { var worker workable - if timedMode { - worker = newTimedWorker(requests, results, done, float64(baton.configuration.duration)) + if preparedRunConfiguration.timedMode { + worker = newTimedWorker(preparedRunConfiguration.requests, preparedRunConfiguration.results, preparedRunConfiguration.done, float64(baton.configuration.duration)) } else { - worker = newCountWorker(requests, results, done) + worker = newCountWorker(preparedRunConfiguration.requests, preparedRunConfiguration.results, preparedRunConfiguration.done) } - worker.setCustomClient(client) - if preloadedRequestsMode { - go worker.sendRequests(preloadedRequests) + worker.setCustomClient(preparedRunConfiguration.client) + if preparedRunConfiguration.preLoadedRequestsMode { + go worker.sendRequests(preparedRunConfiguration.preLoadedRequests) } else { - request := preloadedRequest{baton.configuration.method, baton.configuration.url, baton.configuration.body, [][]string{}} + request := preLoadedRequest{baton.configuration.method, baton.configuration.url, preparedRunConfiguration.body, [][]string{}} go worker.sendRequest(request) } } // Wait for all the workers to finish and then stop the timer for a := 1; a <= baton.configuration.concurrency; a++ { - <-done + <-preparedRunConfiguration.done } baton.result.timeTaken = time.Since(start) @@ -170,7 +119,7 @@ func (baton *Baton) run() { log.Println("Processing the results...") for a := 1; a <= baton.configuration.concurrency; a++ { - result := <-results + result := <-preparedRunConfiguration.results baton.result.httpResult.connectionErrorCount += result.connectionErrorCount baton.result.httpResult.status1xxCount += result.status1xxCount baton.result.httpResult.status2xxCount += result.status2xxCount @@ -183,3 +132,81 @@ func (baton *Baton) run() { baton.result.requestsPerSecond = int(float64(baton.result.totalRequests)/baton.result.timeTaken.Seconds() + 0.5) } + +func configureLogging(suppressOutput bool) { + + logWriter := &logWriter{true} + + if suppressOutput { + logWriter.Disable() + } + + log.SetFlags(0) + log.SetOutput(logWriter) +} + +func prepareRun(configuration Configuration) (runConfiguration, error) { + + preLoadedRequestsMode := false + timedMode := false + + var preLoadedRequests []preLoadedRequest + + if configuration.requestsFromFile != "" { + var err error + preLoadedRequests, err = preLoadRequestsFromFile(configuration.requestsFromFile) + preLoadedRequestsMode = true + if err != nil { + return runConfiguration{}, errors.New("failed to parse requests from file: " + configuration.requestsFromFile) + } + } + + if configuration.duration != 0 { + timedMode = true + } + + client := &fasthttp.Client{} + if configuration.ignoreTLS { + tlsConfig := &tls.Config{InsecureSkipVerify: true} + client = &fasthttp.Client{TLSConfig: tlsConfig} + } + + body := configuration.body + if configuration.dataFilePath != "" { + data, err := ioutil.ReadFile(configuration.dataFilePath) + if err != nil { + return runConfiguration{}, err + } + body = string(data) + } + + if preLoadedRequestsMode { + log.Printf("Configuring to send requests from file. (Read %d requests)\n", len(preLoadedRequests)) + } else { + log.Printf("Configuring to send %s requests to: %s\n", configuration.method, configuration.url) + } + + requests := make(chan bool, configuration.numberOfRequests) + results := make(chan HTTPResult, configuration.concurrency) + done := make(chan bool, configuration.concurrency) + + log.Println("Generating the requests...") + for r := 1; r <= configuration.numberOfRequests; r++ { + requests <- true + } + close(requests) + log.Println("Finished generating the requests") + + preparedRunConfiguration := runConfiguration{ + preLoadedRequestsMode, + timedMode, + preLoadedRequests, + client, + requests, + results, + done, + body, + } + + return preparedRunConfiguration, nil +} diff --git a/configuration.go b/configuration.go new file mode 100644 index 0000000..0eaecd3 --- /dev/null +++ b/configuration.go @@ -0,0 +1,29 @@ +package main + +import ( + "errors" +) + +// Configuration represents the Baton configuration +type Configuration struct { + body string + concurrency int + dataFilePath string + duration int + ignoreTLS bool + method string + numberOfRequests int + requestsFromFile string + suppressOutput bool + url string + wait int +} + +func (configuration *Configuration) validate() error { + + if configuration.concurrency < 1 || configuration.numberOfRequests == 0 { + return errors.New("invalid concurrency level or number of requests") + } + + return nil +} diff --git a/count_worker.go b/count_worker.go index dc7c428..16812c7 100644 --- a/count_worker.go +++ b/count_worker.go @@ -14,7 +14,7 @@ func newCountWorker(requests <-chan bool, results chan<- HTTPResult, done chan<- return &countWorker{worker} } -func (worker *countWorker) sendRequest(request preloadedRequest) { +func (worker *countWorker) sendRequest(request preLoadedRequest) { req := fasthttp.AcquireRequest() req.SetRequestURI(request.url) req.Header.SetMethod(request.method) @@ -27,7 +27,7 @@ func (worker *countWorker) sendRequest(request preloadedRequest) { worker.finish() } -func (worker *countWorker) sendRequests(requests []preloadedRequest) { +func (worker *countWorker) sendRequests(requests []preLoadedRequest) { totalPremadeRequests := len(requests) for range worker.requests { diff --git a/csv_parsing.go b/csv_parsing.go index 84e0787..c113ec1 100644 --- a/csv_parsing.go +++ b/csv_parsing.go @@ -17,7 +17,7 @@ func extractHeaders(rawHeaders string) []string { return nil } -func preloadRequestsFromFile(filename string) ([]preloadedRequest, error) { +func preLoadRequestsFromFile(filename string) ([]preLoadedRequest, error) { file, err := os.Open(filename) if err != nil { @@ -25,7 +25,7 @@ func preloadRequestsFromFile(filename string) ([]preloadedRequest, error) { } reader := csv.NewReader(bufio.NewReader(file)) - var requests []preloadedRequest + var requests []preLoadedRequest for { record, err := reader.Read() @@ -64,7 +64,7 @@ func preloadRequestsFromFile(filename string) ([]preloadedRequest, error) { } } - requests = append(requests, preloadedRequest{method, url, body, headers}) + requests = append(requests, preLoadedRequest{method, url, body, headers}) } return requests, nil diff --git a/timed_worker.go b/timed_worker.go index 205eaf5..4e992c9 100644 --- a/timed_worker.go +++ b/timed_worker.go @@ -16,7 +16,7 @@ func newTimedWorker(requests <-chan bool, results chan<- HTTPResult, done chan<- return &timedWorker{worker, durationToRun} } -func (worker timedWorker) sendRequest(request preloadedRequest) { +func (worker timedWorker) sendRequest(request preLoadedRequest) { req := fasthttp.AcquireRequest() req.SetRequestURI(request.url) req.Header.SetMethod(request.method) @@ -35,7 +35,7 @@ func (worker timedWorker) sendRequest(request preloadedRequest) { worker.finish() } -func (worker timedWorker) sendRequests(requests []preloadedRequest) { +func (worker timedWorker) sendRequests(requests []preLoadedRequest) { totalPremadeRequests := len(requests) startTime := time.Now() diff --git a/worker.go b/worker.go index 8207412..2724a04 100644 --- a/worker.go +++ b/worker.go @@ -14,8 +14,8 @@ type worker struct { } type workable interface { - sendRequests(requests []preloadedRequest) - sendRequest(request preloadedRequest) + sendRequests(requests []preLoadedRequest) + sendRequest(request preLoadedRequest) setCustomClient(client *fasthttp.Client) } @@ -50,8 +50,8 @@ func (worker *worker) performRequest(req *fasthttp.Request, resp *fasthttp.Respo return false } -func buildRequest(requests []preloadedRequest, totalPremadeRequests int) (*fasthttp.Request, *fasthttp.Response) { - var currentReq preloadedRequest +func buildRequest(requests []preLoadedRequest, totalPremadeRequests int) (*fasthttp.Request, *fasthttp.Response) { + var currentReq preLoadedRequest currentReq = requests[rand.Intn(totalPremadeRequests)] req := fasthttp.AcquireRequest()