diff --git a/.gitignore b/.gitignore index c28db86..3d2eeb7 100644 --- a/.gitignore +++ b/.gitignore @@ -49,4 +49,3 @@ Thumbs.db /server server.log -whatshooked diff --git a/cmd/cli/client.go b/cmd/cli/client.go index baf45d7..bc1a24a 100644 --- a/cmd/cli/client.go +++ b/cmd/cli/client.go @@ -102,6 +102,16 @@ func decodeJSON(resp *http.Response, target interface{}) error { return json.NewDecoder(resp.Body).Decode(target) } +// serverAvailable checks if the server is reachable via the health endpoint. +func serverAvailable(client *Client) bool { + resp, err := client.client.Get(client.baseURL + "/health") + if err != nil { + return false + } + resp.Body.Close() + return resp.StatusCode < 500 +} + // checkError prints error and exits if error is not nil func checkError(err error) { if err != nil { diff --git a/cmd/cli/commands_accounts.go b/cmd/cli/commands_accounts.go index 3f29287..102feef 100644 --- a/cmd/cli/commands_accounts.go +++ b/cmd/cli/commands_accounts.go @@ -1,13 +1,28 @@ package main import ( + "context" + "encoding/json" "fmt" + "os" + "text/tabwriter" + "time" - "git.warky.dev/wdevs/whatshooked/pkg/config" + "git.warky.dev/wdevs/whatshooked/pkg/models" + "git.warky.dev/wdevs/whatshooked/pkg/storage" + resolvespec_common "github.com/bitechdev/ResolveSpec/pkg/spectypes" + "github.com/google/uuid" "github.com/spf13/cobra" ) -// accountsCmd is the parent command for account management +var ( + accountUser string + accountPhoneNumber string + accountType string + accountSessionPath string + accountDisplayName string +) + var accountsCmd = &cobra.Command{ Use: "accounts", Short: "Manage WhatsApp accounts", @@ -35,17 +50,47 @@ var accountsAddCmd = &cobra.Command{ }, } +var accountsRemoveCmd = &cobra.Command{ + Use: "remove ", + Short: "Remove a WhatsApp account by ID", + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + client := NewClient(cliConfig) + removeAccount(client, args[0]) + }, +} + func init() { + accountsAddCmd.Flags().StringVarP(&accountPhoneNumber, "phone", "p", "", "Phone number (with country code)") + accountsAddCmd.Flags().StringVarP(&accountType, "type", "t", "whatsmeow", "Account type (whatsmeow/business-api)") + accountsAddCmd.Flags().StringVarP(&accountSessionPath, "session-path", "s", "", "Session path (auto-generated if omitted)") + accountsAddCmd.Flags().StringVarP(&accountDisplayName, "display-name", "d", "", "Display name") + accountsAddCmd.Flags().StringVarP(&accountUser, "user", "u", "", "Owner username for DB mode (default: first admin)") + accountsCmd.AddCommand(accountsListCmd) accountsCmd.AddCommand(accountsAddCmd) + accountsCmd.AddCommand(accountsRemoveCmd) } func listAccounts(client *Client) { + if serverAvailable(client) { + listAccountsHTTP(client) + } else { + fmt.Println("[server unavailable, reading from database]") + if !tryInitDB() { + fmt.Println("Error: server unreachable and no database config found. Use --server-config to specify config path.") + return + } + listAccountsDB() + } +} + +func listAccountsHTTP(client *Client) { resp, err := client.Get("/api/accounts") checkError(err) defer resp.Body.Close() - var accounts []config.WhatsAppConfig + var accounts []map[string]interface{} checkError(decodeJSON(resp, &accounts)) if len(accounts) == 0 { @@ -53,37 +98,140 @@ func listAccounts(client *Client) { return } - fmt.Printf("Configured accounts (%d):\n\n", len(accounts)) + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "ID\tPHONE\tTYPE\tSTATUS\tACTIVE") for _, acc := range accounts { - fmt.Printf("ID: %s\n", acc.ID) - fmt.Printf("Phone Number: %s\n", acc.PhoneNumber) - fmt.Printf("Session Path: %s\n", acc.SessionPath) - fmt.Println() + fmt.Fprintf(w, "%v\t%v\t%v\t%v\t%v\n", + acc["id"], acc["phone_number"], acc["account_type"], acc["status"], acc["active"]) } + w.Flush() +} + +func listAccountsDB() { + var accounts []models.ModelPublicWhatsappAccount + err := storage.DB.NewSelect().Model(&accounts).OrderExpr("created_at ASC").Scan(context.Background()) + checkError(err) + + if len(accounts) == 0 { + fmt.Println("No accounts configured") + return + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "ID\tPHONE\tTYPE\tSTATUS\tACTIVE") + for _, acc := range accounts { + active := "yes" + if !acc.Active { + active = "no" + } + fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\n", + acc.ID.String(), + acc.PhoneNumber.String(), + acc.AccountType.String(), + acc.Status.String(), + active, + ) + } + w.Flush() } func addAccount(client *Client) { - var account config.WhatsAppConfig - - fmt.Print("Account ID: ") - if _, err := fmt.Scanln(&account.ID); err != nil { - checkError(fmt.Errorf("error reading account ID: %v", err)) + phone := accountPhoneNumber + if phone == "" { + phone = promptRequired("Phone Number (with country code)") } - fmt.Print("Phone Number (with country code): ") - if _, err := fmt.Scanln(&account.PhoneNumber); err != nil { - checkError(fmt.Errorf("error reading phone number: %v", err)) + acctType := accountType + if acctType == "" { + acctType = promptLine("Account type (whatsmeow/business-api)", "whatsmeow") } - fmt.Print("Session Path: ") - if _, err := fmt.Scanln(&account.SessionPath); err != nil { - checkError(fmt.Errorf("error reading session path: %v", err)) + displayName := accountDisplayName + if displayName == "" { + displayName = promptLine("Display name (optional)", "") } - resp, err := client.Post("/api/accounts/add", account) + if serverAvailable(client) { + addAccountHTTP(client, phone, acctType, displayName) + } else { + fmt.Println("[server unavailable, writing to database]") + if !tryInitDB() { + fmt.Println("Error: server unreachable and no database config found. Use --server-config to specify config path.") + return + } + addAccountDB(phone, acctType, displayName) + } +} + +func addAccountHTTP(client *Client, phone, acctType, displayName string) { + payload := map[string]interface{}{ + "phone_number": phone, + "account_type": acctType, + "display_name": displayName, + } + resp, err := client.Post("/api/accounts/add", payload) checkError(err) defer resp.Body.Close() - fmt.Println("Account added successfully") fmt.Println("Check server logs for QR code to pair the device") } + +func addAccountDB(phone, acctType, displayName string) { + userID := dbOwnerUserID(accountUser) + if userID == "" { + fmt.Println("Error: no users found in database. Create a user first with: users add") + return + } + + id := uuid.New().String() + sessionPath := accountSessionPath + if sessionPath == "" { + sessionPath = fmt.Sprintf("./sessions/%s", id) + } + + now := time.Now() + var cfgJSON string + if acctType == "business-api" { + b, _ := json.Marshal(map[string]string{}) + cfgJSON = string(b) + } + + account := &models.ModelPublicWhatsappAccount{ + ID: resolvespec_common.NewSqlString(id), + PhoneNumber: resolvespec_common.NewSqlString(phone), + AccountType: resolvespec_common.NewSqlString(acctType), + DisplayName: resolvespec_common.NewSqlString(displayName), + SessionPath: resolvespec_common.NewSqlString(sessionPath), + Config: resolvespec_common.NewSqlString(cfgJSON), + Status: resolvespec_common.NewSqlString("disconnected"), + UserID: resolvespec_common.NewSqlString(userID), + Active: true, + CreatedAt: resolvespec_common.NewSqlTimeStamp(now), + UpdatedAt: resolvespec_common.NewSqlTimeStamp(now), + } + + repo := storage.NewWhatsAppAccountRepository(storage.DB) + checkError(repo.Create(context.Background(), account)) + + fmt.Printf("Account '%s' added (ID: %s)\n", phone, id) + fmt.Println("Start the server to connect and pair the device") +} + +func removeAccount(client *Client, id string) { + if serverAvailable(client) { + resp, err := client.Post("/api/accounts/remove", map[string]string{"id": id}) + checkError(err) + defer resp.Body.Close() + fmt.Println("Account removed successfully") + } else { + fmt.Println("[server unavailable, removing from database]") + if !tryInitDB() { + fmt.Println("Error: server unreachable and no database config found. Use --server-config to specify config path.") + return + } + repo := storage.NewWhatsAppAccountRepository(storage.DB) + checkError(repo.Delete(context.Background(), id)) + fmt.Printf("Account '%s' removed\n", id) + } +} + diff --git a/cmd/cli/commands_hooks.go b/cmd/cli/commands_hooks.go index 88d5567..197e0fb 100644 --- a/cmd/cli/commands_hooks.go +++ b/cmd/cli/commands_hooks.go @@ -1,16 +1,23 @@ package main import ( - "bufio" + "context" + "encoding/json" "fmt" "os" "strings" + "text/tabwriter" + "time" - "git.warky.dev/wdevs/whatshooked/pkg/config" + "git.warky.dev/wdevs/whatshooked/pkg/models" + "git.warky.dev/wdevs/whatshooked/pkg/storage" + resolvespec_common "github.com/bitechdev/ResolveSpec/pkg/spectypes" + "github.com/google/uuid" "github.com/spf13/cobra" ) -// hooksCmd is the parent command for hook management +var hookUser string + var hooksCmd = &cobra.Command{ Use: "hooks", Short: "Manage webhooks", @@ -40,7 +47,7 @@ var hooksAddCmd = &cobra.Command{ var hooksRemoveCmd = &cobra.Command{ Use: "remove ", - Short: "Remove a hook", + Short: "Remove a hook by ID", Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { client := NewClient(cliConfig) @@ -49,17 +56,31 @@ var hooksRemoveCmd = &cobra.Command{ } func init() { + hooksAddCmd.Flags().StringVarP(&hookUser, "user", "u", "", "Owner username for DB mode (default: first admin)") hooksCmd.AddCommand(hooksListCmd) hooksCmd.AddCommand(hooksAddCmd) hooksCmd.AddCommand(hooksRemoveCmd) } func listHooks(client *Client) { + if serverAvailable(client) { + listHooksHTTP(client) + } else { + fmt.Println("[server unavailable, reading from database]") + if !tryInitDB() { + fmt.Println("Error: server unreachable and no database config found. Use --server-config to specify config path.") + return + } + listHooksDB() + } +} + +func listHooksHTTP(client *Client) { resp, err := client.Get("/api/hooks") checkError(err) defer resp.Body.Close() - var hooks []config.Hook + var hooks []map[string]interface{} checkError(decodeJSON(resp, &hooks)) if len(hooks) == 0 { @@ -67,98 +88,153 @@ func listHooks(client *Client) { return } - fmt.Printf("Configured hooks (%d):\n\n", len(hooks)) - for _, hook := range hooks { - status := "inactive" - if hook.Active { - status = "active" - } - fmt.Printf("ID: %s\n", hook.ID) - fmt.Printf("Name: %s\n", hook.Name) - fmt.Printf("URL: %s\n", hook.URL) - fmt.Printf("Method: %s\n", hook.Method) - fmt.Printf("Status: %s\n", status) - if len(hook.Events) > 0 { - fmt.Printf("Events: %v\n", hook.Events) - } else { - fmt.Printf("Events: all (no filter)\n") - } - if hook.Description != "" { - fmt.Printf("Description: %s\n", hook.Description) - } - fmt.Println() + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "ID\tNAME\tURL\tMETHOD\tACTIVE") + for _, h := range hooks { + fmt.Fprintf(w, "%v\t%v\t%v\t%v\t%v\n", + h["id"], h["name"], h["url"], h["method"], h["active"]) } + w.Flush() +} + +func listHooksDB() { + var hooks []models.ModelPublicHook + err := storage.DB.NewSelect().Model(&hooks).OrderExpr("created_at ASC").Scan(context.Background()) + checkError(err) + + if len(hooks) == 0 { + fmt.Println("No hooks configured") + return + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "ID\tNAME\tURL\tMETHOD\tACTIVE\tEVENTS") + for _, h := range hooks { + active := "yes" + if !h.Active { + active = "no" + } + events := parseEventsJSON(h.Events.String()) + fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n", + h.ID.String(), + h.Name.String(), + h.URL.String(), + h.Method.String(), + active, + events, + ) + } + w.Flush() } func addHook(client *Client) { - var hook config.Hook - scanner := bufio.NewScanner(os.Stdin) + name := promptRequired("Hook Name") + url := promptRequired("Webhook URL") + method := promptLine("HTTP Method", "POST") - fmt.Print("Hook ID: ") - if _, err := fmt.Scanln(&hook.ID); err != nil { - checkError(fmt.Errorf("error reading hook ID: %v", err)) - } - - fmt.Print("Hook Name: ") - if _, err := fmt.Scanln(&hook.Name); err != nil { - checkError(fmt.Errorf("error reading hook name: %v", err)) - } - - fmt.Print("Webhook URL: ") - if _, err := fmt.Scanln(&hook.URL); err != nil { - checkError(fmt.Errorf("error reading webhook URL: %v", err)) - } - - fmt.Print("HTTP Method (POST): ") - if _, err := fmt.Scanln(&hook.Method); err == nil { - // Successfully read input - fmt.Printf("Selected Method %s", hook.Method) - } - if hook.Method == "" { - hook.Method = "POST" - } - - // Prompt for events with helpful examples fmt.Println("\nAvailable events:") - fmt.Println(" WhatsApp: whatsapp.connected, whatsapp.disconnected, whatsapp.qr.code") - fmt.Println(" Messages: message.received, message.sent, message.delivered, message.read") - fmt.Println(" Hooks: hook.triggered, hook.success, hook.failed") - fmt.Print("\nEvents (comma-separated, or press Enter for all): ") + fmt.Println(" whatsapp.connected, whatsapp.disconnected, whatsapp.qr.code") + fmt.Println(" message.received, message.sent, message.delivered, message.read") + fmt.Println(" hook.triggered, hook.success, hook.failed") + eventsRaw := promptLine("\nEvents (comma-separated, or Enter for all)", "") + description := promptLine("Description (optional)", "") - scanner.Scan() - eventsInput := strings.TrimSpace(scanner.Text()) - - if eventsInput != "" { - // Split by comma and trim whitespace - eventsList := strings.Split(eventsInput, ",") - hook.Events = make([]string, 0, len(eventsList)) - for _, event := range eventsList { - trimmed := strings.TrimSpace(event) - if trimmed != "" { - hook.Events = append(hook.Events, trimmed) + var events []string + if eventsRaw != "" { + for _, e := range strings.Split(eventsRaw, ",") { + if t := strings.TrimSpace(e); t != "" { + events = append(events, t) } } } - fmt.Print("\nDescription (optional): ") - scanner.Scan() - hook.Description = strings.TrimSpace(scanner.Text()) + if serverAvailable(client) { + addHookHTTP(client, name, url, method, description, events) + } else { + fmt.Println("[server unavailable, writing to database]") + if !tryInitDB() { + fmt.Println("Error: server unreachable and no database config found. Use --server-config to specify config path.") + return + } + addHookDB(name, url, method, description, events) + } +} - hook.Active = true - - resp, err := client.Post("/api/hooks/add", hook) +func addHookHTTP(client *Client, name, url, method, description string, events []string) { + payload := map[string]interface{}{ + "name": name, + "url": url, + "method": method, + "description": description, + "events": events, + "active": true, + } + resp, err := client.Post("/api/hooks/add", payload) checkError(err) defer resp.Body.Close() - fmt.Println("Hook added successfully") } -func removeHook(client *Client, id string) { - req := map[string]string{"id": id} +func addHookDB(name, url, method, description string, events []string) { + userID := dbOwnerUserID(hookUser) + if userID == "" { + fmt.Println("Error: no users found in database. Create a user first with: users add") + return + } - resp, err := client.Post("/api/hooks/remove", req) - checkError(err) - defer resp.Body.Close() + eventsJSON := "[]" + if len(events) > 0 { + b, _ := json.Marshal(events) + eventsJSON = string(b) + } - fmt.Println("Hook removed successfully") + id := uuid.New().String() + now := time.Now() + hook := &models.ModelPublicHook{ + ID: resolvespec_common.NewSqlString(id), + Name: resolvespec_common.NewSqlString(name), + URL: resolvespec_common.NewSqlString(url), + Method: resolvespec_common.NewSqlString(method), + Description: resolvespec_common.NewSqlString(description), + Events: resolvespec_common.NewSqlString(eventsJSON), + UserID: resolvespec_common.NewSqlString(userID), + Active: true, + CreatedAt: resolvespec_common.NewSqlTimeStamp(now), + UpdatedAt: resolvespec_common.NewSqlTimeStamp(now), + } + + repo := storage.NewHookRepository(storage.DB) + checkError(repo.Create(context.Background(), hook)) + + fmt.Printf("Hook '%s' added (ID: %s)\n", name, id) +} + +func removeHook(client *Client, id string) { + if serverAvailable(client) { + resp, err := client.Post("/api/hooks/remove", map[string]string{"id": id}) + checkError(err) + defer resp.Body.Close() + fmt.Println("Hook removed successfully") + } else { + fmt.Println("[server unavailable, removing from database]") + if !tryInitDB() { + fmt.Println("Error: server unreachable and no database config found. Use --server-config to specify config path.") + return + } + repo := storage.NewHookRepository(storage.DB) + checkError(repo.Delete(context.Background(), id)) + fmt.Printf("Hook '%s' removed\n", id) + } +} + +// parseEventsJSON parses a JSON events string into a comma-separated display string. +func parseEventsJSON(raw string) string { + if raw == "" || raw == "[]" { + return "all" + } + var events []string + if err := json.Unmarshal([]byte(raw), &events); err != nil { + return raw + } + return strings.Join(events, ", ") } diff --git a/cmd/cli/commands_users.go b/cmd/cli/commands_users.go index 74c1c2e..111821a 100644 --- a/cmd/cli/commands_users.go +++ b/cmd/cli/commands_users.go @@ -4,6 +4,7 @@ import ( "bufio" "context" "fmt" + "io" "os" "path/filepath" "strings" @@ -22,6 +23,23 @@ import ( var serverConfigPath string +// flags for add +var ( + addUsername string + addEmail string + addFullName string + addPassword string + addRole string +) + +// flags for update +var ( + updateUsername string + updateEmail string + updateFullName string + updateRole string +) + var usersCmd = &cobra.Command{ Use: "users", Short: "Manage users (direct DB access)", @@ -50,11 +68,15 @@ var usersAddCmd = &cobra.Command{ } var usersSetPasswordCmd = &cobra.Command{ - Use: "set-password ", + Use: "set-password [password]", Short: "Change a user's password", - Args: cobra.ExactArgs(1), + Args: cobra.RangeArgs(1, 2), Run: func(cmd *cobra.Command, args []string) { - setPassword(args[0]) + if len(args) == 2 { + setPasswordDirect(args[0], args[1]) + } else { + setPassword(args[0]) + } }, } @@ -105,6 +127,18 @@ var usersRemoveCmd = &cobra.Command{ func init() { usersCmd.PersistentFlags().StringVar(&serverConfigPath, "server-config", "", "server config file (default: config.json or ~/.whatshooked/config.json)") + + usersAddCmd.Flags().StringVarP(&addUsername, "username", "u", "", "Username") + usersAddCmd.Flags().StringVarP(&addEmail, "email", "e", "", "Email address") + usersAddCmd.Flags().StringVarP(&addFullName, "full-name", "n", "", "Full name") + usersAddCmd.Flags().StringVarP(&addPassword, "password", "p", "", "Password") + usersAddCmd.Flags().StringVarP(&addRole, "role", "r", "", "Role (admin/user)") + + usersUpdateCmd.Flags().StringVarP(&updateUsername, "username", "u", "", "New username") + usersUpdateCmd.Flags().StringVarP(&updateEmail, "email", "e", "", "New email address") + usersUpdateCmd.Flags().StringVarP(&updateFullName, "full-name", "n", "", "New full name") + usersUpdateCmd.Flags().StringVarP(&updateRole, "role", "r", "", "New role (admin/user)") + usersCmd.AddCommand(usersListCmd) usersCmd.AddCommand(usersAddCmd) usersCmd.AddCommand(usersUpdateCmd) @@ -132,6 +166,41 @@ func resolveServerConfigPath() string { return "" } +// tryInitDB attempts to initialize the DB, returning false if it cannot. +func tryInitDB() bool { + if storage.DB != nil { + return true + } + cfgPath := resolveServerConfigPath() + if cfgPath == "" { + return false + } + cfg, err := config.Load(cfgPath) + if err != nil { + return false + } + return storage.Initialize(&cfg.Database) == nil +} + +// dbOwnerUserID returns the ID of the first admin user (or first user) for use +// when creating records via DB that require a user_id. +func dbOwnerUserID(username string) string { + var user models.ModelPublicUsers + q := storage.DB.NewSelect().Model(&user) + if username != "" { + q = q.Where("username = ?", username) + } else { + q = q.Where("role = ?", "admin") + } + if err := q.Limit(1).Scan(context.Background()); err != nil { + // fall back to any user + if err2 := storage.DB.NewSelect().Model(&user).Limit(1).Scan(context.Background()); err2 != nil { + return "" + } + } + return user.ID.String() +} + func initUserDB() { cfgPath := resolveServerConfigPath() if cfgPath == "" { @@ -183,39 +252,29 @@ func listUsers() { } func addUser() { - scanner := bufio.NewScanner(os.Stdin) - - fmt.Print("Username: ") - scanner.Scan() - username := strings.TrimSpace(scanner.Text()) + username := addUsername if username == "" { - fmt.Fprintln(os.Stderr, "Error: username is required") - os.Exit(1) + username = promptRequired("Username") } - fmt.Print("Email: ") - scanner.Scan() - email := strings.TrimSpace(scanner.Text()) + email := addEmail if email == "" { - fmt.Fprintln(os.Stderr, "Error: email is required") - os.Exit(1) + email = promptRequired("Email") } - fmt.Print("Full Name (optional): ") - scanner.Scan() - fullName := strings.TrimSpace(scanner.Text()) + fullName := addFullName + if fullName == "" { + fullName = promptLine("Full Name (optional)", "") + } - fmt.Print("Role (admin/user) [user]: ") - scanner.Scan() - role := strings.TrimSpace(scanner.Text()) + role := addRole if role == "" { - role = "user" + role = promptLine("Role (admin/user)", "user") } - password := readPassword("Password: ") + password := addPassword if password == "" { - fmt.Fprintln(os.Stderr, "Error: password is required") - os.Exit(1) + password = readPassword("Password: ") } hashedPw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) @@ -240,8 +299,21 @@ func addUser() { fmt.Printf("User '%s' created\n", username) } -func setPassword(username string) { +func setPasswordDirect(username, password string) { user := findUserByUsername(username) + hashedPw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + checkError(err) + _, err = storage.DB.NewUpdate().Model(user). + Set("password = ?", string(hashedPw)). + Set("updated_at = ?", time.Now()). + Where("id = ?", user.ID.String()). + Exec(context.Background()) + checkError(err) + fmt.Printf("Password updated for '%s'\n", username) +} + +func setPassword(username string) { + findUserByUsername(username) // validate exists password := readPassword("New password: ") confirm := readPassword("Confirm password: ") @@ -249,15 +321,15 @@ func setPassword(username string) { fmt.Fprintln(os.Stderr, "Error: passwords do not match") os.Exit(1) } - if password == "" { - fmt.Fprintln(os.Stderr, "Error: password cannot be empty") - os.Exit(1) - } hashedPw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) checkError(err) - _, err = storage.DB.NewUpdate().Model(user). + var user models.ModelPublicUsers + err = storage.DB.NewSelect().Model(&user).Where("username = ?", username).Scan(context.Background()) + checkError(err) + + _, err = storage.DB.NewUpdate().Model(&user). Set("password = ?", string(hashedPw)). Set("updated_at = ?", time.Now()). Where("id = ?", user.ID.String()). @@ -286,36 +358,40 @@ func setUserActive(username string, active bool) { func updateUser(username string) { user := findUserByUsername(username) - scanner := bufio.NewScanner(os.Stdin) - fmt.Printf("Updating user '%s' (press Enter to keep current value)\n", username) + // If any flag was provided, apply flags only (non-interactive) + flagsProvided := updateUsername != "" || updateEmail != "" || updateFullName != "" || updateRole != "" - fmt.Printf("Username [%s]: ", user.Username.String()) - scanner.Scan() - if v := strings.TrimSpace(scanner.Text()); v != "" { - user.Username = resolvespec_common.NewSqlString(v) + if flagsProvided { + if updateUsername != "" { + user.Username = resolvespec_common.NewSqlString(updateUsername) + } + if updateEmail != "" { + user.Email = resolvespec_common.NewSqlString(updateEmail) + } + if updateFullName != "" { + user.FullName = resolvespec_common.NewSqlString(updateFullName) + } + if updateRole != "" { + user.Role = resolvespec_common.NewSqlString(updateRole) + } + } else { + fmt.Printf("Updating user '%s' (press Enter to keep current value)\n", username) + + if v := promptLine("Username", user.Username.String()); v != "" { + user.Username = resolvespec_common.NewSqlString(v) + } + if v := promptLine("Email", user.Email.String()); v != "" { + user.Email = resolvespec_common.NewSqlString(v) + } + if v := promptLine("Full Name", user.FullName.String()); v != "" { + user.FullName = resolvespec_common.NewSqlString(v) + } + if v := promptLine("Role (admin/user)", user.Role.String()); v != "" { + user.Role = resolvespec_common.NewSqlString(v) + } } - fmt.Printf("Email [%s]: ", user.Email.String()) - scanner.Scan() - if v := strings.TrimSpace(scanner.Text()); v != "" { - user.Email = resolvespec_common.NewSqlString(v) - } - - fmt.Printf("Full Name [%s]: ", user.FullName.String()) - scanner.Scan() - if v := strings.TrimSpace(scanner.Text()); v != "" { - user.FullName = resolvespec_common.NewSqlString(v) - } - - fmt.Printf("Role (admin/user) [%s]: ", user.Role.String()) - scanner.Scan() - if v := strings.TrimSpace(scanner.Text()); v != "" { - user.Role = resolvespec_common.NewSqlString(v) - } - - user.UpdatedAt = resolvespec_common.NewSqlTimeStamp(time.Now()) - _, err := storage.DB.NewUpdate().Model(user). Set("username = ?", user.Username.String()). Set("email = ?", user.Email.String()). @@ -333,9 +409,9 @@ func deleteUser(username string) { user := findUserByUsername(username) fmt.Printf("Delete user '%s'? This cannot be undone. [y/N]: ", username) - scanner := bufio.NewScanner(os.Stdin) - scanner.Scan() - if strings.ToLower(strings.TrimSpace(scanner.Text())) != "y" { + reader := bufio.NewReader(os.Stdin) + line, _ := reader.ReadString('\n') + if strings.ToLower(strings.TrimSpace(line)) != "y" { fmt.Println("Cancelled") return } @@ -356,14 +432,63 @@ func findUserByUsername(username string) *models.ModelPublicUsers { return &user } -func readPassword(prompt string) string { - fmt.Print(prompt) - pw, err := term.ReadPassword(int(os.Stdin.Fd())) - fmt.Println() - if err != nil { - scanner := bufio.NewScanner(os.Stdin) - scanner.Scan() - return strings.TrimSpace(scanner.Text()) +// promptLine prints a prompt and reads one line. Returns defaultVal if Enter pressed with no input. +func promptLine(label, defaultVal string) string { + if defaultVal != "" { + fmt.Printf("%s [%s]: ", label, defaultVal) + } else { + fmt.Printf("%s: ", label) + } + reader := bufio.NewReader(os.Stdin) + line, err := reader.ReadString('\n') + if err != nil && err != io.EOF { + return defaultVal + } + v := strings.TrimRight(line, "\r\n") + if v == "" { + return defaultVal + } + return v +} + +// promptRequired loops until a non-empty value is entered. +func promptRequired(label string) string { + for { + v := promptLine(label, "") + if v != "" { + return v + } + fmt.Printf("%s is required, please try again.\n", label) + } +} + +// readPassword reads a password with hidden input when on a TTY, plain text otherwise. +// Loops until a non-empty value is entered. +func readPassword(prompt string) string { + fd := int(os.Stdin.Fd()) + isTTY := term.IsTerminal(fd) + for { + fmt.Print(prompt) + var pw string + if isTTY { + p, err := term.ReadPassword(fd) + fmt.Println() + if err == nil { + pw = string(p) + } + } else { + reader := bufio.NewReader(os.Stdin) + p, err := reader.ReadString('\n') + fmt.Println() + if err != nil && err != io.EOF { + fmt.Fprintln(os.Stderr, "Error: cannot read password interactively. Use set-password instead.") + os.Exit(1) + } + pw = strings.TrimRight(p, "\r\n") + } + if pw != "" { + return pw + } + fmt.Println("Password cannot be empty, please try again.") } - return string(pw) }