refactor: reuse logic for config init and set

This commit is contained in:
Henrique Dias
2025-11-17 09:16:54 +01:00
parent 8c5dc7641e
commit 89be0b1873
3 changed files with 132 additions and 229 deletions

View File

@@ -37,6 +37,11 @@ func addConfigFlags(flags *pflag.FlagSet) {
flags.Uint("minimumPasswordLength", settings.DefaultMinimumPasswordLength, "minimum password length for new users")
flags.String("shell", "", "shell command to which other commands should be appended")
// NB: these are string so they can be presented as octal in the help text
// as that's the conventional representation for modes in Unix.
flags.String("fileMode", fmt.Sprintf("%O", settings.DefaultFileMode), "mode bits that new files are created with")
flags.String("dirMode", fmt.Sprintf("%O", settings.DefaultDirMode), "mode bits that new directories are created with")
flags.String("auth.method", string(auth.MethodJSONAuth), "authentication type")
flags.String("auth.header", "", "HTTP header for auth.method=proxy")
flags.String("auth.command", "", "command for auth.method=hook")
@@ -52,11 +57,6 @@ func addConfigFlags(flags *pflag.FlagSet) {
flags.Bool("branding.disableExternal", false, "disable external links such as GitHub links")
flags.Bool("branding.disableUsedPercentage", false, "disable used disk percentage graph")
// NB: these are string so they can be presented as octal in the help text
// as that's the conventional representation for modes in Unix.
flags.String("fileMode", fmt.Sprintf("%O", settings.DefaultFileMode), "mode bits that new files are created with")
flags.String("dirMode", fmt.Sprintf("%O", settings.DefaultDirMode), "mode bits that new directories are created with")
flags.Uint64("tus.chunkSize", settings.DefaultTusChunkSize, "the tus chunk size")
flags.Uint16("tus.retryCount", settings.DefaultTusRetryCount, "the tus retry count")
}
@@ -266,3 +266,118 @@ func printSettings(ser *settings.Server, set *settings.Settings, auther auth.Aut
fmt.Printf("\nAuther configuration (raw):\n\n%s\n\n", string(b))
return nil
}
func getSettings(flags *pflag.FlagSet, set *settings.Settings, ser *settings.Server, auther auth.Auther, all bool) (auth.Auther, error) {
errs := []error{}
hasAuth := false
visit := func(flag *pflag.Flag) {
var err error
switch flag.Name {
// Server flags from [addServerFlags]
case "address":
ser.Address, err = flags.GetString(flag.Name)
case "log":
ser.Log, err = flags.GetString(flag.Name)
case "port":
ser.Port, err = flags.GetString(flag.Name)
case "cert":
ser.TLSCert, err = flags.GetString(flag.Name)
case "key":
ser.TLSKey, err = flags.GetString(flag.Name)
case "root":
ser.Root, err = flags.GetString(flag.Name)
case "socket":
ser.Socket, err = flags.GetString(flag.Name)
case "baseURL":
ser.BaseURL, err = flags.GetString(flag.Name)
case "tokenExpirationTime":
ser.TokenExpirationTime, err = flags.GetString(flag.Name)
case "disableThumbnails":
ser.EnableThumbnails, err = flags.GetBool(flag.Name)
ser.EnableThumbnails = !ser.EnableThumbnails
case "disablePreviewResize":
ser.ResizePreview, err = flags.GetBool(flag.Name)
ser.ResizePreview = !ser.ResizePreview
case "disableExec":
ser.EnableExec, err = flags.GetBool(flag.Name)
ser.EnableExec = !ser.EnableExec
case "disableTypeDetectionByHeader":
ser.TypeDetectionByHeader, err = flags.GetBool(flag.Name)
ser.TypeDetectionByHeader = !ser.TypeDetectionByHeader
// Settings flags from [addConfigFlags]
case "signup":
set.Signup, err = flags.GetBool(flag.Name)
case "hideLoginButton":
set.HideLoginButton, err = flags.GetBool(flag.Name)
case "createUserDir":
set.CreateUserDir, err = flags.GetBool(flag.Name)
case "minimumPasswordLength":
set.MinimumPasswordLength, err = flags.GetUint(flag.Name)
case "shell":
var shell string
shell, err = flags.GetString(flag.Name)
if err == nil {
set.Shell = convertCmdStrToCmdArray(shell)
}
case "auth.method":
hasAuth = true
case "branding.name":
set.Branding.Name, err = flags.GetString(flag.Name)
case "branding.theme":
set.Branding.Theme, err = flags.GetString(flag.Name)
case "branding.color":
set.Branding.Color, err = flags.GetString(flag.Name)
case "branding.files":
set.Branding.Files, err = flags.GetString(flag.Name)
case "branding.disableExternal":
set.Branding.DisableExternal, err = flags.GetBool(flag.Name)
case "branding.disableUsedPercentage":
set.Branding.DisableUsedPercentage, err = flags.GetBool(flag.Name)
case "fileMode":
set.FileMode, err = getAndParseFileMode(flags, flag.Name)
case "dirMode":
set.DirMode, err = getAndParseFileMode(flags, flag.Name)
case "tus.chunkSize":
set.Tus.ChunkSize, err = flags.GetUint64(flag.Name)
case "tus.retryCount":
set.Tus.RetryCount, err = flags.GetUint16(flag.Name)
}
if err != nil {
errs = append(errs, err)
}
}
if all {
flags.VisitAll(visit)
} else {
flags.Visit(visit)
}
err := nerrors.Join(errs...)
if err != nil {
return nil, err
}
err = getUserDefaults(flags, &set.Defaults, all)
if err != nil {
return nil, err
}
if all {
set.AuthMethod, auther, err = getAuthentication(flags)
if err != nil {
return nil, err
}
} else {
set.AuthMethod, auther, err = getAuthentication(flags, hasAuth, set, auther)
if err != nil {
return nil, err
}
}
return auther, nil
}