From 88e1d6b126ad38809503893ad31b5bd64501078a Mon Sep 17 00:00:00 2001 From: eval-exec Date: Mon, 21 Mar 2022 22:36:39 +0800 Subject: [PATCH] using subtests for TestValidateSecureEndpoints() --- client/pkg/transport/tls_test.go | 72 ++++++++++++++++++++++++++------ 1 file changed, 59 insertions(+), 13 deletions(-) diff --git a/client/pkg/transport/tls_test.go b/client/pkg/transport/tls_test.go index e2747f87b..aa2615257 100644 --- a/client/pkg/transport/tls_test.go +++ b/client/pkg/transport/tls_test.go @@ -17,7 +17,7 @@ package transport import ( "net/http" "net/http/httptest" - "strings" + "reflect" "testing" ) @@ -33,18 +33,64 @@ func TestValidateSecureEndpoints(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(remoteAddr)) defer srv.Close() - insecureEps := []string{ - "http://" + srv.Listener.Addr().String(), - "invalid remote address", + tests := map[string]struct { + endPoints []string + expectedEndpoints []string + expectedErr bool + }{ + "invalidEndPoints": { + endPoints: []string{ + "invalid endpoint", + }, + expectedEndpoints: nil, + expectedErr: true, + }, + "insecureEndpoints": { + endPoints: []string{ + "http://127.0.0.1:8000", + "http://" + srv.Listener.Addr().String(), + }, + expectedEndpoints: nil, + expectedErr: true, + }, + "secureEndPoints": { + endPoints: []string{ + "https://" + srv.Listener.Addr().String(), + }, + expectedEndpoints: []string{ + "https://" + srv.Listener.Addr().String(), + }, + expectedErr: false, + }, + "mixEndPoints": { + endPoints: []string{ + "https://" + srv.Listener.Addr().String(), + "http://" + srv.Listener.Addr().String(), + "invalid end points", + }, + expectedEndpoints: []string{ + "https://" + srv.Listener.Addr().String(), + }, + expectedErr: true, + }, } - if _, err := ValidateSecureEndpoints(*tlsInfo, insecureEps); err == nil || !strings.Contains(err.Error(), "is insecure") { - t.Error("validate secure endpoints should fail") - } - - secureEps := []string{ - "https://" + srv.Listener.Addr().String(), - } - if _, err := ValidateSecureEndpoints(*tlsInfo, secureEps); err != nil { - t.Error("validate secure endpoints should succeed") + for name, test := range tests { + t.Run(name, func(t *testing.T) { + secureEps, err := ValidateSecureEndpoints(*tlsInfo, test.endPoints) + if test.expectedErr && err == nil { + t.Errorf("expected error") + } + if !test.expectedErr && err != nil { + t.Errorf("unexpected error: %v", err) + } + if err == nil && !test.expectedErr { + if len(secureEps) != len(test.expectedEndpoints) { + t.Errorf("expected %v endpoints, got %v", len(test.expectedEndpoints), len(secureEps)) + } + if !reflect.DeepEqual(test.expectedEndpoints, secureEps) { + t.Errorf("expected endpoints %v, got %v", test.expectedEndpoints, secureEps) + } + } + }) } }