diff --git a/cmd/kots/cli/get-joincommand.go b/cmd/kots/cli/get-joincommand.go new file mode 100644 index 0000000000..554739e5fe --- /dev/null +++ b/cmd/kots/cli/get-joincommand.go @@ -0,0 +1,192 @@ +package cli + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/replicatedhq/kots/pkg/api/handlers/types" + "github.com/replicatedhq/kots/pkg/auth" + "github.com/replicatedhq/kots/pkg/k8sutil" + "github.com/spf13/cobra" + "github.com/spf13/viper" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" +) + +func GetJoinCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "join-command", + Short: "Get embedded cluster join command", + Long: "", + SilenceUsage: false, + SilenceErrors: false, + Hidden: true, + PreRun: func(cmd *cobra.Command, args []string) { + viper.BindPFlags(cmd.Flags()) + }, + RunE: func(cmd *cobra.Command, args []string) error { + v := viper.GetViper() + + clientset, err := k8sutil.GetClientset() + if err != nil { + return fmt.Errorf("failed to get clientset: %w", err) + } + + namespace, err := getNamespaceOrDefault(v.GetString("namespace")) + if err != nil { + return fmt.Errorf("failed to get namespace: %w", err) + } + + joinCmd, err := getJoinCommandCmd(cmd.Context(), clientset, namespace) + if err != nil { + return err + } + + format := v.GetString("output") + if format == "string" || format == "" { + fmt.Println(strings.Join(joinCmd, " ")) + return nil + } else if format == "json" { + type joinCommandResponse struct { + Command []string `json:"command"` + } + joinCmdResponse := joinCommandResponse{ + Command: joinCmd, + } + b, err := json.Marshal(joinCmdResponse) + if err != nil { + return fmt.Errorf("failed to marshal join command: %w", err) + } + fmt.Println(string(b)) + return nil + } + + return fmt.Errorf("invalid output format: %s", format) + }, + } + cmd.Flags().StringP("output", "o", "", "output format (currently supported: json)") + + return cmd +} + +func getJoinCommandCmd(ctx context.Context, clientset kubernetes.Interface, namespace string) ([]string, error) { + // determine the IP address and port of the kotsadm service + // this only runs inside an embedded cluster and so we don't need to setup port forwarding + svc, err := clientset.CoreV1().Services(namespace).Get(ctx, "kotsadm", metav1.GetOptions{}) + if err != nil { + return nil, fmt.Errorf("unable to get kotsadm service: %w", err) + } + kotsadmIP := svc.Spec.ClusterIP + if kotsadmIP == "" { + return nil, fmt.Errorf("kotsadm service ip was empty") + } + + if len(svc.Spec.Ports) == 0 { + return nil, fmt.Errorf("kotsadm service ports were empty") + } + kotsadmPort := svc.Spec.Ports[0].Port + + authSlug, err := auth.GetOrCreateAuthSlug(clientset, namespace) + if err != nil { + return nil, fmt.Errorf("failed to get kotsadm auth slug: %w", err) + } + + url := fmt.Sprintf("http://%s:%d/api/v1/embedded-cluster/roles", kotsadmIP, kotsadmPort) + roles, err := getRoles(url, authSlug) + if err != nil { + return nil, fmt.Errorf("failed to get roles: %w", err) + } + + controllerRole := roles.ControllerRoleName + if controllerRole == "" && len(roles.Roles) > 0 { + controllerRole = roles.Roles[0] + } + if controllerRole == "" { + return nil, fmt.Errorf("unable to determine controller role name") + } + + // get a join command with the controller role with a post to /api/v1/embedded-cluster/generate-node-join-command + url = fmt.Sprintf("http://%s:%d/api/v1/embedded-cluster/generate-node-join-command", kotsadmIP, kotsadmPort) + joinCommand, err := getJoinCommand(url, authSlug, []string{controllerRole}) + if err != nil { + return nil, fmt.Errorf("failed to get join command: %w", err) + } + + return joinCommand.Command, nil +} + +// determine the embedded cluster roles list from /api/v1/embedded-cluster/roles +func getRoles(url string, authSlug string) (*types.GetEmbeddedClusterRolesResponse, error) { + newReq, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + newReq.Header.Add("Content-Type", "application/json") + newReq.Header.Add("Authorization", authSlug) + + resp, err := http.DefaultClient.Do(newReq) + if err != nil { + return nil, fmt.Errorf("failed to execute request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + b, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + roles := &types.GetEmbeddedClusterRolesResponse{} + if err := json.Unmarshal(b, roles); err != nil { + return nil, fmt.Errorf("failed to unmarshal roles: %w", err) + } + + return roles, nil +} + +func getJoinCommand(url string, authSlug string, roles []string) (*types.GenerateEmbeddedClusterNodeJoinCommandResponse, error) { + payload := types.GenerateEmbeddedClusterNodeJoinCommandRequest{ + Roles: roles, + } + b, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal roles: %w", err) + } + + newReq, err := http.NewRequest("POST", url, bytes.NewBuffer(b)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + newReq.Header.Add("Content-Type", "application/json") + newReq.Header.Add("Authorization", authSlug) + + resp, err := http.DefaultClient.Do(newReq) + if err != nil { + return nil, fmt.Errorf("failed to execute request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + fullResponse, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + joinCommand := &types.GenerateEmbeddedClusterNodeJoinCommandResponse{} + if err := json.Unmarshal(fullResponse, joinCommand); err != nil { + return nil, fmt.Errorf("failed to unmarshal roles: %w", err) + } + + return joinCommand, nil +} diff --git a/cmd/kots/cli/get-joincommand_test.go b/cmd/kots/cli/get-joincommand_test.go new file mode 100644 index 0000000000..579cf78dc1 --- /dev/null +++ b/cmd/kots/cli/get-joincommand_test.go @@ -0,0 +1,214 @@ +package cli + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "testing" + + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/kubernetes/fake" +) + +func TestGetJoinCommand(t *testing.T) { + tests := []struct { + name string + service *corev1.Service + secret *corev1.Secret + handler http.HandlerFunc + expectedError string + expectedCmd []string + }{ + { + name: "successful join command generation", + service: &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kotsadm", + Namespace: "kotsadm", + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "127.0.0.1", + Ports: []corev1.ServicePort{ + {}, + }, + }, + }, + secret: &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kotsadm-authstring", + Namespace: "kotsadm", + }, + Data: map[string][]byte{ + "kotsadm-authstring": []byte("test-auth-token"), + }, + }, + handler: func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case "GET": + require.Equal(t, "/api/v1/embedded-cluster/roles", r.URL.Path) + require.Equal(t, "test-auth-token", r.Header.Get("Authorization")) + + response := map[string]interface{}{ + "roles": []string{"controller-role-name-normally-not-different", "worker"}, + "controllerRoleName": "test-controller-role-name", + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + case "POST": + require.Equal(t, "/api/v1/embedded-cluster/generate-node-join-command", r.URL.Path) + require.Equal(t, "test-auth-token", r.Header.Get("Authorization")) + require.Equal(t, "application/json", r.Header.Get("Content-Type")) + + var requestBody struct { + Roles []string `json:"roles"` + } + err := json.NewDecoder(r.Body).Decode(&requestBody) + require.NoError(t, err) + require.Equal(t, []string{"test-controller-role-name"}, requestBody.Roles) + + response := map[string][]string{ + "command": {"embedded-cluster", "join", "--token", "test-token"}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + } + }, + expectedCmd: []string{"embedded-cluster", "join", "--token", "test-token"}, + }, + { + name: "missing service", + service: nil, + expectedError: "unable to get kotsadm service", + }, + { + name: "server returns error status when fetching roles", + service: &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kotsadm", + Namespace: "kotsadm", + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "127.0.0.1", + Ports: []corev1.ServicePort{ + {}, + }, + }, + }, + secret: &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kotsadm-authstring", + Namespace: "kotsadm", + }, + Data: map[string][]byte{ + "kotsadm-authstring": []byte("test-auth-token"), + }, + }, + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + response := map[string]string{ + "error": "internal server error", + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + }, + expectedError: "failed to get roles: unexpected status code: 500", + }, + { + name: "server returns error status when creating token", + service: &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kotsadm", + Namespace: "kotsadm", + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "127.0.0.1", + Ports: []corev1.ServicePort{ + {}, + }, + }, + }, + secret: &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kotsadm-authstring", + Namespace: "kotsadm", + }, + Data: map[string][]byte{ + "kotsadm-authstring": []byte("test-auth-token"), + }, + }, + + handler: func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case "GET": + require.Equal(t, "/api/v1/embedded-cluster/roles", r.URL.Path) + require.Equal(t, "test-auth-token", r.Header.Get("Authorization")) + + response := map[string]interface{}{ + "roles": []string{"controller-role-name-normally-not-different", "worker"}, + "controllerRoleName": "test-controller-role-name", + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + case "POST": + w.WriteHeader(http.StatusInternalServerError) + response := map[string]string{ + "error": "internal server error", + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + } + }, + expectedError: "failed to get join command: unexpected status code: 500", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Create a test server if we have a handler + var server *httptest.Server + if test.handler != nil { + server = httptest.NewServer(test.handler) + defer server.Close() + + // Update the service IP and port to match the test server + serverURL, err := url.Parse(server.URL) + require.NoError(t, err) + + host := serverURL.Hostname() + port, err := strconv.ParseInt(serverURL.Port(), 10, 32) + require.NoError(t, err) + + test.service.Spec.ClusterIP = host + test.service.Spec.Ports[0].Port = int32(port) + } + + // Create fake client with test objects + var objects []runtime.Object + if test.service != nil { + objects = append(objects, test.service) + } + if test.secret != nil { + objects = append(objects, test.secret) + } + fakeClient := fake.NewSimpleClientset(objects...) + + // Call GetJoinCommand + cmd, err := getJoinCommandCmd(context.Background(), fakeClient, "kotsadm") + + // Verify results + if test.expectedError != "" { + require.Error(t, err) + require.Contains(t, err.Error(), test.expectedError) + } else { + require.NoError(t, err) + require.Equal(t, test.expectedCmd, cmd) + } + }) + } +} diff --git a/cmd/kots/cli/get.go b/cmd/kots/cli/get.go index f6bc6c49de..2d21831c87 100644 --- a/cmd/kots/cli/get.go +++ b/cmd/kots/cli/get.go @@ -34,6 +34,7 @@ kubectl kots get apps`, cmd.AddCommand(GetVersionsCmd()) cmd.AddCommand(GetConfigCmd()) cmd.AddCommand(GetRestoresCmd()) + cmd.AddCommand(GetJoinCmd()) return cmd } diff --git a/pkg/api/handlers/types/types.go b/pkg/api/handlers/types/types.go index 65a0d497ce..e72d03da18 100644 --- a/pkg/api/handlers/types/types.go +++ b/pkg/api/handlers/types/types.go @@ -98,3 +98,16 @@ type ResponsePendingApp struct { LicenseData string `json:"licenseData"` NeedsRegistry bool `json:"needsRegistry"` } + +type GetEmbeddedClusterRolesResponse struct { + Roles []string `json:"roles"` + ControllerRoleName string `json:"controllerRoleName"` +} + +type GenerateEmbeddedClusterNodeJoinCommandRequest struct { + Roles []string `json:"roles"` +} + +type GenerateEmbeddedClusterNodeJoinCommandResponse struct { + Command []string `json:"command"` +} diff --git a/pkg/embeddedcluster/roles.go b/pkg/embeddedcluster/roles.go index 74679a25fe..67b666eeb5 100644 --- a/pkg/embeddedcluster/roles.go +++ b/pkg/embeddedcluster/roles.go @@ -16,10 +16,10 @@ const DEFAULT_CONTROLLER_ROLE_NAME = "controller" var labelValueRegex = regexp.MustCompile(`[^a-zA-Z0-9-_.]+`) // GetRoles will get a list of role names -func GetRoles(ctx context.Context, kbClient kbclient.Client) ([]string, error) { +func GetRoles(ctx context.Context, kbClient kbclient.Client) ([]string, string, error) { config, err := ClusterConfig(ctx, kbClient) if err != nil { - return nil, fmt.Errorf("failed to get cluster config: %w", err) + return nil, "", fmt.Errorf("failed to get cluster config: %w", err) } if config == nil { @@ -29,7 +29,9 @@ func GetRoles(ctx context.Context, kbClient kbclient.Client) ([]string, error) { // determine role names roles := []string{} + controllerRoleName := DEFAULT_CONTROLLER_ROLE_NAME if config.Roles.Controller.Name != "" { + controllerRoleName = config.Roles.Controller.Name roles = append(roles, config.Roles.Controller.Name) } else { roles = append(roles, DEFAULT_CONTROLLER_ROLE_NAME) @@ -41,7 +43,7 @@ func GetRoles(ctx context.Context, kbClient kbclient.Client) ([]string, error) { } } - return roles, nil + return roles, controllerRoleName, nil } // ControllerRoleName determines the name for the 'controller' role diff --git a/pkg/handlers/embedded_cluster_get.go b/pkg/handlers/embedded_cluster_get.go index 624835058e..066991ba0e 100644 --- a/pkg/handlers/embedded_cluster_get.go +++ b/pkg/handlers/embedded_cluster_get.go @@ -4,16 +4,13 @@ import ( "net/http" "github.com/gorilla/mux" + "github.com/replicatedhq/kots/pkg/api/handlers/types" "github.com/replicatedhq/kots/pkg/embeddedcluster" "github.com/replicatedhq/kots/pkg/k8sutil" "github.com/replicatedhq/kots/pkg/logger" "github.com/replicatedhq/kots/pkg/util" ) -type GetEmbeddedClusterRolesResponse struct { - Roles []string `json:"roles"` -} - func (h *Handler) GetEmbeddedClusterNodes(w http.ResponseWriter, r *http.Request) { if !util.IsEmbeddedCluster() { logger.Errorf("not an embedded cluster") @@ -75,11 +72,11 @@ func (h *Handler) GetEmbeddedClusterRoles(w http.ResponseWriter, r *http.Request return } - roles, err := embeddedcluster.GetRoles(r.Context(), kbClient) + roles, controllerRoleName, err := embeddedcluster.GetRoles(r.Context(), kbClient) if err != nil { logger.Error(err) w.WriteHeader(http.StatusInternalServerError) return } - JSON(w, http.StatusOK, GetEmbeddedClusterRolesResponse{Roles: roles}) + JSON(w, http.StatusOK, types.GetEmbeddedClusterRolesResponse{Roles: roles, ControllerRoleName: controllerRoleName}) } diff --git a/pkg/handlers/embedded_cluster_node_join_command.go b/pkg/handlers/embedded_cluster_node_join_command.go index db16fa4870..9c2c2d3304 100644 --- a/pkg/handlers/embedded_cluster_node_join_command.go +++ b/pkg/handlers/embedded_cluster_node_join_command.go @@ -8,6 +8,7 @@ import ( "github.com/google/uuid" "github.com/replicatedhq/embedded-cluster/kinds/types/join" + "github.com/replicatedhq/kots/pkg/api/handlers/types" "github.com/replicatedhq/kots/pkg/embeddedcluster" "github.com/replicatedhq/kots/pkg/k8sutil" "github.com/replicatedhq/kots/pkg/kotsutil" @@ -16,14 +17,6 @@ import ( "github.com/replicatedhq/kots/pkg/util" ) -type GenerateEmbeddedClusterNodeJoinCommandResponse struct { - Command []string `json:"command"` -} - -type GenerateEmbeddedClusterNodeJoinCommandRequest struct { - Roles []string `json:"roles"` -} - func (h *Handler) GenerateEmbeddedClusterNodeJoinCommand(w http.ResponseWriter, r *http.Request) { if !util.IsEmbeddedCluster() { logger.Errorf("not an embedded cluster") @@ -31,7 +24,7 @@ func (h *Handler) GenerateEmbeddedClusterNodeJoinCommand(w http.ResponseWriter, return } - generateEmbeddedClusterNodeJoinCommandRequest := GenerateEmbeddedClusterNodeJoinCommandRequest{} + generateEmbeddedClusterNodeJoinCommandRequest := types.GenerateEmbeddedClusterNodeJoinCommandRequest{} if err := json.NewDecoder(r.Body).Decode(&generateEmbeddedClusterNodeJoinCommandRequest); err != nil { logger.Error(fmt.Errorf("failed to decode request body: %w", err)) w.WriteHeader(http.StatusBadRequest) @@ -72,7 +65,7 @@ func (h *Handler) GenerateEmbeddedClusterNodeJoinCommand(w http.ResponseWriter, return } - JSON(w, http.StatusOK, GenerateEmbeddedClusterNodeJoinCommandResponse{ + JSON(w, http.StatusOK, types.GenerateEmbeddedClusterNodeJoinCommandResponse{ Command: []string{nodeJoinCommand}, }) }