args.um

fn mk

Creates a new argument parser. Tip: Use args::stdargv() to get the arguments from the command line.

fn mk*(programName: str, programVer: str, programDesc: str, args: []str): Args {

fn args.mode

Adds a mode to the argument parser. Returns true if you are in that mode.

fn (args: ^Args) mode*(name: str, desc: str = ""): bool {

fn args.required

Adds a required argument to the argument parser.

fn (args: ^Args) required*(ptr: any, name: str, desc: str = ""): ^Arg {

fn args.optional

Adds an optional argument to the argument parser.

fn (args: ^Args) optional*(ptr: any, name: str, desc: str = ""): ^Arg {

fn args.requiredNext

Adds a positional required argument to the argument parser.

fn (args: ^Args) requiredNext*(ptr: any, name: str, desc: str = ""): ^Arg {

fn args.optionalNext

Adds a positional optional argument to the argument parser.

fn (args: ^Args) optionalNext*(ptr: any, name: str, desc: str = ""): ^Arg {

fn args.optionalCutoff

Allows you to capture arguments to pass them to another program.

fn (args: ^Args) optionalCutoff*(ptr: any, name: str, desc: str = ""): ^Arg {

fn args.requiredCutoff

Allows you to capture arguments to pass them to another program. Requires at least one argument.

fn (args: ^Args) requiredCutoff*(ptr: any, name: str, desc: str = ""): ^Arg {

fn arg.short

Adds a short name to the argument.

fn (arg: ^Arg) short*(ch: char) {

fn args.help

Adds a --help argument to the argument parser.

fn (args: ^Args) help*() {

fn args.parse

Parses the command line arguments and returns a list of errors.

fn (args: ^Args) parse*(): []Error {

fn args.getusage

Returns the usage/help message.

fn (args: ^Args) getusage*(): str {

fn args.usage

Prints the usage/help message and exits.

fn (args: ^Args) usage*() {

fn args.stdargv

Returns the command line arguments.

fn stdargv*(): []str {

fn args.check

Parses the command line arguments and exits if there are any errors.

fn (args: ^Args) check*() {

import (
    "std.um"
)

type (
    // A named sub-command ("mode") that can be selected by the first
    // remaining command-line argument.
    Mode* = struct {
        name: str // word that selects the mode
        desc: str // description shown in the usage message
    }

    // A single registered argument.
    Arg* = struct {
        name: str        // long name, matched on the command line as `--name`
        desc: str        // description shown in the usage message
        value: any       // pointer to the destination variable, wrapped in `any`
        required: bool   // whether parsing reports an error when it is missing
        occurences: int  // number of times the argument was parsed
        positional: bool // true when registered through one of the *Next/*Cutoff methods
        shortname: char  // single-character alias matched as `-c`; '\0' means none
        cutoff: bool     // for a list positional: capture all remaining arguments
    }

    // A parse error; `msg` is the human-readable text.
    Error* = struct {
        msg: str
    }

    // The argument parser state.
    Args* = struct {
        programName: str     // printed in the first line of the usage message
        programVer: str      // printed in the first line of the usage message
        programDesc: str     // printed in the first line of the usage message
        sourceArgs: []str    // arguments still to be parsed (program name removed)
        args: []^Arg         // every registered argument, positional or not
        modes: []Mode        // modes registered so far; cleared once a mode matches
        positional: []^Arg   // non-list positional arguments, in registration order
        listPositional: ^Arg // array-typed positional argument, or null if none
        gethelp: bool        // destination of the `--help` flag added by help()
    }
)

// Reports whether `v` holds a pointer to one of the supported map types
// (string keys; scalar or array values).
fn ismaptype(v: any): bool {
    isMap := false

    switch p := type(v) {
        // scalar values
        case ^map[str]bool:     isMap = true
        case ^map[str]str:      isMap = true
        case ^map[str]real:     isMap = true
        case ^map[str]real32:   isMap = true
        case ^map[str]int:      isMap = true
        case ^map[str]int32:    isMap = true
        case ^map[str]uint:     isMap = true
        case ^map[str]uint32:   isMap = true
        // array values
        case ^map[str][]str:    isMap = true
        case ^map[str][]real:   isMap = true
        case ^map[str][]real32: isMap = true
        case ^map[str][]int:    isMap = true
        case ^map[str][]int32:  isMap = true
        case ^map[str][]uint:   isMap = true
        case ^map[str][]uint32: isMap = true
        default:
    }

    return isMap
}

// Reports whether `v` holds a pointer to one of the supported dynamic
// array types.
fn isarrtype(v: any): bool {
    isArr := false

    switch p := type(v) {
        case ^[]str:    isArr = true
        case ^[]real:   isArr = true
        case ^[]real32: isArr = true
        case ^[]int:    isArr = true
        case ^[]int32:  isArr = true
        case ^[]uint:   isArr = true
        case ^[]uint32: isArr = true
        default:
    }

    return isArr
}

// Returns the label printed next to an argument's name in the usage message
// (it is formatted with `%-11s` by getusage).
//
// As written, every scalar and array type, and `^map[str]bool`, yields an
// empty string; every other map type yields a single space; unsupported
// types fall through to the final `return ""`.
//
// NOTE(review): labels that are empty or a lone space look like placeholder
// text that was lost (e.g. angle-bracketed names stripped by a tool) --
// confirm against the upstream source before relying on this output.
fn typelbl(v: any): str {
    switch vv := type(v) {
        case ^bool: return ""
        case ^str: return ""
        case ^real: return ""
        case ^real32: return ""
        case ^int: return ""
        case ^int32: return ""
        case ^uint: return ""
        case ^uint32: return ""
        case ^[]str: return ""
        case ^[]real: return ""
        case ^[]real32: return ""
        case ^[]int: return ""
        case ^[]int32: return ""
        case ^[]uint: return ""
        case ^[]uint32: return ""
        case ^map[str]bool: return ""
        case ^map[str]str: return " "
        case ^map[str]real: return " "
        case ^map[str]real32: return " "
        case ^map[str]int: return " "
        case ^map[str]int32: return " "
        case ^map[str]uint: return " "
        case ^map[str]uint32: return " "
        case ^map[str][]str: return " "
        case ^map[str][]real: return " "
        case ^map[str][]real32: return " "
        case ^map[str][]int: return " "
        case ^map[str][]int32: return " "
        case ^map[str][]uint: return " "
        case ^map[str][]uint32: return " "
        default:
    }

    return ""
}

// Reports whether `s` looks like a named argument: a '-' followed by a
// letter, '_' or another '-'. Strings such as "-5" are therefore not treated
// as arguments, so negative numbers can be used as values.
//
// The length is checked before any character is read, so an empty string
// (e.g. an explicit "" on the command line) returns false instead of being
// indexed out of range.
fn argstart(s: str): bool {
    if len(s) < 2 || s[0] != '-' {
        return false
    }

    c := s[1]
    return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_' || c == '-'
}

// Reports whether `s` starts like a real number: a digit, or one of
// '-', '+', '.' immediately followed by a digit. Only the first one or two
// characters are inspected; the rest of the string is not validated.
//
// An empty string returns false instead of being indexed out of range.
fn realstart(s: str): bool {
    if len(s) == 0 {
        return false
    }

    if (s[0] == '-' || s[0] == '+' || s[0] == '.') && len(s) > 1 {
        return s[1] >= '0' && s[1] <= '9'
    }
    return s[0] >= '0' && s[0] <= '9'
}

// Reports whether `s` starts like a signed integer: a digit, or '-' / '+'
// immediately followed by a digit. Only the first one or two characters are
// inspected; the rest of the string is not validated.
//
// An empty string returns false instead of being indexed out of range.
fn intstart(s: str): bool {
    if len(s) == 0 {
        return false
    }

    if (s[0] == '-' || s[0] == '+') && len(s) > 1 {
        return s[1] >= '0' && s[1] <= '9'
    }
    return s[0] >= '0' && s[0] <= '9'
}

// Reports whether `s` starts like an unsigned integer: a digit, or '+'
// immediately followed by a digit. Only the first one or two characters are
// inspected; the rest of the string is not validated.
//
// An empty string returns false instead of being indexed out of range.
fn uintstart(s: str): bool {
    if len(s) == 0 {
        return false
    }

    if s[0] == '+' && len(s) > 1 {
        return s[1] >= '0' && s[1] <= '9'
    }
    return s[0] >= '0' && s[0] <= '9'
}

//~~fn mk
// Creates a new argument parser.
// Tip: Use `args::stdargv()` to get the arguments from the command line.
fn mk*(programName: str, programVer: str, programDesc: str, args: []str): Args {
//~~
    // The first element of `args` is the program name and is not parsed.
    // An empty `args` has nothing to skip; slicing it from index 1 would be
    // out of range, so it is mapped to an empty argument list instead.
    rest := []str{}
    if len(args) > 0 {
        rest = slice(args, 1)
    }

    return {
        programName: programName,
        programVer: programVer,
        programDesc: programDesc,
        sourceArgs: rest
    }
}

//~~fn args.mode
// Adds a mode to the argument parser.
// Returns true if you are in that mode.
fn (args: ^Args) mode*(name: str, desc: str = ""): bool {
//~~
    // Modes consume the leading argument, so they must be declared before
    // any positional argument is registered.
    std::assert(len(args.positional) == 0 && args.listPositional == null, "Cannot use mode after positional arguments")

    // Remember the mode so the usage message can list it.
    args.modes = append(args.modes, Mode{name: name, desc: desc})

    selected := len(args.sourceArgs) > 0 && args.sourceArgs[0] == name
    if selected {
        // Entering a mode: forget the modes collected so far and drop the
        // mode word from the arguments that remain to be parsed.
        args.modes = {}
        args.sourceArgs = slice(args.sourceArgs, 1)
    }

    return selected
}

// Registers an argument and returns it.
//   required - parsing reports an error if the argument is missing
//   next     - the argument is positional
//   cutoff   - stored on the argument; parse uses it for list positionals
// Positional arguments are checked for ordering/type rules; a violation
// aborts through std::assert.
fn (args: ^Args) arg(name: str, ptr: any, desc: str, required: bool, next: bool, cutoff: bool): ^Arg {
    entry := &Arg{name: name, desc: desc, value: ptr, required: required, positional: next, cutoff: cutoff}
    args.args = append(args.args, entry)

    // Non-positional arguments need no further bookkeeping.
    if !next {
        return entry
    }

    asBool := ^bool(ptr)
    std::assert(asBool == null, "Cannot use bool as positional argument")
    std::assert(args.listPositional == null, "Cannot put a positional argument after a list positional argument")

    count := len(args.positional)
    if required && count > 0 {
        std::assert(args.positional[count - 1].required, "Cannot use a required positional argument after optional positional argument")
    }

    // An array destination becomes the single list positional; everything
    // else is an ordinary positional slot.
    if isarrtype(ptr) {
        args.listPositional = entry
    } else {
        args.positional = append(args.positional, entry)
    }

    return entry
}

//~~fn args.required
// Adds a required argument to the argument parser.
fn (args: ^Args) required*(ptr: any, name: str, desc: str = ""): ^Arg {
//~~
    // Named (`--name`) argument that must be present.
    isRequired := true
    isPositional := false
    isCutoff := false
    return args.arg(name, ptr, desc, isRequired, isPositional, isCutoff)
}

//~~fn args.optional
// Adds an optional argument to the argument parser.
fn (args: ^Args) optional*(ptr: any, name: str, desc: str = ""): ^Arg {
//~~
    // Named (`--name`) argument that may be omitted.
    isRequired := false
    isPositional := false
    isCutoff := false
    return args.arg(name, ptr, desc, isRequired, isPositional, isCutoff)
}

//~~fn args.requiredNext
// Adds a positional required argument to the argument parser.
fn (args: ^Args) requiredNext*(ptr: any, name: str, desc: str = ""): ^Arg {
//~~
    // Positional argument that must be present.
    isRequired := true
    isPositional := true
    isCutoff := false
    return args.arg(name, ptr, desc, isRequired, isPositional, isCutoff)
}

//~~fn args.optionalNext
// Adds a positional optional argument to the argument parser.
fn (args: ^Args) optionalNext*(ptr: any, name: str, desc: str = ""): ^Arg {
//~~
    // Positional argument that may be omitted.
    isRequired := false
    isPositional := true
    isCutoff := false
    return args.arg(name, ptr, desc, isRequired, isPositional, isCutoff)
}

//~~fn args.optionalCutoff
// Allows you to capture arguments to pass them to another program.
fn (args: ^Args) optionalCutoff*(ptr: any, name: str, desc: str = ""): ^Arg {
//~~
    // Positional argument with the cutoff flag set; may be empty.
    isRequired := false
    isPositional := true
    isCutoff := true
    return args.arg(name, ptr, desc, isRequired, isPositional, isCutoff)
}

//~~fn args.requiredCutoff
// Allows you to capture arguments to pass them to another program.
// Requires at least one argument.
fn (args: ^Args) requiredCutoff*(ptr: any, name: str, desc: str = ""): ^Arg {
//~~
    // Positional argument with the cutoff flag set; must not be empty.
    isRequired := true
    isPositional := true
    isCutoff := true
    return args.arg(name, ptr, desc, isRequired, isPositional, isCutoff)
}


//~~fn arg.short
// Adds a short name to the argument.
fn (arg: ^Arg) short*(ch: char) {
//~~
    // Stored as a single character; parse matches it as "-" followed by it.
    arg^.shortname = ch
}

//~~fn args.help
// Adds a --help argument to the argument parser.
fn (args: ^Args) help*() {
//~~
    // The flag writes into args.gethelp, which check() and usage() read.
    helpArg := args.optional(&args.gethelp, "help", "Show this help message")
    helpArg.short('h')
}

// Parses the value(s) of one argument starting at argv[i] and stores the
// result through the pointer held in arg.value.
//
//   argv   - the argument list; an element may be rewritten in place when a
//            map argument is given as `key=value`
//   i      - index of the first token belonging to this argument's value
//   errors - list that parse errors are appended to
//
// Returns the index of the first token that was not consumed.
fn (arg: ^Arg) parsearg(argv: []str, i: int, errors: ^[]Error): int {
    // Count every occurrence, even if parsing the value fails below.
    arg.occurences += 1
    // A plain bool is a flag: it takes no value and toggles the destination.
    if v := ^bool(arg.value); v != null {
        v ^= !(v^)
        return i
    }

    key := ""

    // Map destinations expect a key first, either as a separate token
    // (`key value`) or joined with '=' (`key=value`).
    if ismaptype(arg.value) {
        if i >= len(argv) {
            errors ^= append(errors^, Error{"Missing key for argument --" + arg.name})
            return i
        }
        key = argv[i]
        at := -1
        // Find the first '='. Note that this loop's `i` shadows the parameter.
        for i, c in key {
            if c == '=' {
                at = i
                break
            }
        }

        if at != -1 {
            // `key=value`: keep the key and leave the value part in argv[i]
            // so the code below reads it as the value token.
            key = slice(key, 0, at)
            argv[i] = slice(argv[i], at+1)
        } else {
            // `key value`: the value is the next token.
            i++
        }
    }

    // A bool map only records that the key was present.
    // NOTE(review): after a `key=value` token `i` has not been advanced here,
    // so the remainder left in argv[i] is not consumed by this call -- confirm
    // that this is intended.
    if v := ^map[str]bool(arg.value); v != null {
        v[key] = true
        return i 
    }

    if i >= len(argv) {
        errors ^= append(errors^, Error{"Missing value for argument --" + arg.name})
        return i
    }

    // validate first
    isvalid := false
    expected := ""

    // Numeric types are checked by looking at the start of the token only
    // (see realstart / intstart / uintstart); strings are always accepted.
    switch v := type(arg.value) {
        case ^str: expected = "string"; isvalid = true
        case ^real: expected = "real number"; isvalid = realstart(argv[i])
        case ^real32: expected = "real number"; isvalid = realstart(argv[i])
        case ^int: expected = "integer"; isvalid = intstart(argv[i])
        case ^int32: expected = "integer"; isvalid = intstart(argv[i])
        case ^uint: expected = "unsigned integer"; isvalid = uintstart(argv[i])
        case ^uint32: expected = "unsigned integer"; isvalid = uintstart(argv[i])
        case ^[]str: expected = "string"; isvalid = true
        case ^[]real: expected = "real number"; isvalid = realstart(argv[i])
        case ^[]real32: expected = "real number"; isvalid = realstart(argv[i])
        case ^[]int: expected = "integer"; isvalid = intstart(argv[i])
        case ^[]int32: expected = "integer"; isvalid = intstart(argv[i])
        case ^[]uint: expected = "unsigned integer"; isvalid = uintstart(argv[i])
        case ^[]uint32: expected = "unsigned integer"; isvalid = uintstart(argv[i])
        case ^map[str]str: expected = "string"; isvalid = true
        case ^map[str]real: expected = "real number"; isvalid = realstart(argv[i])
        case ^map[str]real32: expected = "real number"; isvalid = realstart(argv[i])
        case ^map[str]int: expected = "integer"; isvalid = intstart(argv[i])
        case ^map[str]int32: expected = "integer"; isvalid = intstart(argv[i])
        case ^map[str]uint: expected = "unsigned integer"; isvalid = uintstart(argv[i])
        case ^map[str]uint32: expected = "unsigned integer"; isvalid = uintstart(argv[i])
        case ^map[str][]str: expected = "string"; isvalid = true
        case ^map[str][]real: expected = "real number"; isvalid = realstart(argv[i])
        case ^map[str][]real32: expected = "real number"; isvalid = realstart(argv[i])
        case ^map[str][]int: expected = "integer"; isvalid = intstart(argv[i])
        case ^map[str][]int32: expected = "integer"; isvalid = intstart(argv[i])
        case ^map[str][]uint: expected = "unsigned integer"; isvalid = uintstart(argv[i])
        case ^map[str][]uint32: expected = "unsigned integer"; isvalid = uintstart(argv[i])
        default: std::assert(false, "Unsupported argument type")
    }

    // A validation failure is reported, but the token is still converted and
    // stored below and is still consumed.
    if !isvalid {
        errors ^= append(errors^, Error{"Expected "+expected+" for --"+arg.name})
    }

    // Convert and store: scalars are overwritten, arrays are appended to,
    // maps are written (or appended to) under `key`.
    switch v := type(arg.value) {
        case ^str: v ^= argv[i]
        case ^real: v ^= std::atof(argv[i])
        case ^real32: v ^= std::atof(argv[i])
        case ^int: v ^= std::atoi(argv[i])
        case ^int32: v ^= std::atoi(argv[i])
        case ^uint: v ^= std::atoi(argv[i])
        case ^uint32: v ^= std::atoi(argv[i])
        case ^[]real: v ^= append(v^, std::atof(argv[i]))
        case ^[]real32: v ^= append(v^, std::atof(argv[i]))
        case ^[]int: v ^= append(v^, std::atoi(argv[i]))
        case ^[]int32: v ^= append(v^, std::atoi(argv[i]))
        case ^[]uint: v ^= append(v^, std::atoi(argv[i]))
        case ^[]uint32: v ^= append(v^, std::atoi(argv[i]))
        case ^[]str: v ^= append(v^, argv[i])
        case ^map[str]str: v[key] = argv[i]
        case ^map[str]real: v[key] = std::atof(argv[i])
        case ^map[str]real32: v[key] = std::atof(argv[i])
        case ^map[str]int: v[key] = std::atoi(argv[i])
        case ^map[str]int32: v[key] = std::atoi(argv[i])
        case ^map[str]uint: v[key] = std::atoi(argv[i])
        case ^map[str]uint32: v[key] = std::atoi(argv[i])
        case ^map[str][]str: v[key] = append(v[key], argv[i])
        case ^map[str][]real: v[key] = append(v[key], std::atof(argv[i]))
        case ^map[str][]real32: v[key] = append(v[key], std::atof(argv[i]))
        case ^map[str][]int: v[key] = append(v[key], std::atoi(argv[i]))
        case ^map[str][]int32: v[key] = append(v[key], std::atoi(argv[i]))
        case ^map[str][]uint: v[key] = append(v[key], std::atoi(argv[i]))
        case ^map[str][]uint32: v[key] = append(v[key], std::atoi(argv[i]))
        default:
    }

    // The value token has been consumed.
    i++

    return i
}

//~~fn args.parse
// Parses the command line arguments, returns a list of errors.
fn (args: ^Args) parse*(): []Error {
//~~
    errors := []Error{}
    argv := args.sourceArgs
    nextPositional := 0

    i := 0
    for i < len(argv) {
        token := argv[i]

        // 1. Named arguments: `--name` or `-c`, first registered match wins.
        named := false
        for _, candidate in args.args {
            isLong := "--"+candidate.name == token
            isShort := candidate.shortname != '\0' && "-"+candidate.shortname == token
            if isLong || isShort {
                named = true
                i = candidate.parsearg(argv, i + 1, &errors)
                break
            }
        }
        if named {
            continue
        }

        // 2. Something that looks like an argument but is not registered.
        if argstart(token) {
            errors = append(errors, Error{"Unknown argument: " + token})
            i++
            continue
        }

        // 3. The next free positional slot.
        if nextPositional < len(args.positional) {
            i = args.positional[nextPositional].parsearg(argv, i, &errors)
            nextPositional++
            continue
        }

        // 4. No slot left and no list to absorb the token.
        if args.listPositional == null {
            errors = append(errors, Error{"Stray argument " + token})
            i++
            continue
        }

        // 5. List positional; with cutoff every remaining token is captured
        //    as-is, without looking for named arguments any more.
        i = args.listPositional.parsearg(argv, i, &errors)
        if args.listPositional.cutoff {
            for i < len(argv) {
                args.listPositional.parsearg(argv, i, &errors)
                i++
            }
            break
        }
    }

    // Everything has been consumed.
    args.sourceArgs = {}

    // Required positional slots that were never reached.
    for p := nextPositional; p < len(args.positional); p++ {
        if args.positional[p].required {
            errors = append(errors, Error{"Missing argument for " + args.positional[p].name})
        }
    }

    list := args.listPositional
    if list != null && list.occurences == 0 && list.required {
        errors = append(errors, Error{"Need at least one argument for " + list.name})
    }

    // Required named arguments that never occurred.
    for _, a in args.args {
        if a.positional || !a.required || a.occurences != 0 {
            continue
        }
        errors = append(errors, Error{"Missing required argument --" + a.name})
    }

    return errors
}

//~~fn args.getusage
// Returns the usage/help message.
fn (args: ^Args) getusage*(): str {
//~~
    out := sprintf("%s %s - %s\n", args.programName, args.programVer, args.programDesc)

    // Section 1: registered arguments.
    if len(args.args) > 0 {
        // First pass: build the "--name, -c" labels and find the widest one.
        labels := make([]str, len(args.args))
        widestArg := 0
        for idx, a in args.args {
            if a.shortname == '\0' {
                labels[idx] = sprintf("--%s", a.name)
            } else {
                labels[idx] = sprintf("--%s, -%c", a.name, a.shortname)
            }
            if len(labels[idx]) > widestArg {
                widestArg = len(labels[idx])
            }
        }

        out += "\nArguments:\n"

        // Second pass: one line per argument, labels padded to equal width.
        for idx, a in args.args {
            // Trailing note: "(required)", a default value, or nothing.
            note := ""
            if a.required {
                note = "(required)"
            } else if flag := ^bool(a.value); flag != null {
                // Only flags that start out true are annotated; help never is.
                if flag^ && a.name != "help" {
                    note = "(default: true)"
                }
            } else if !isarrtype(a.value) && !ismaptype(a.value) {
                note = sprintf("(default: %v)", a.value)
            }

            for len(labels[idx]) < widestArg {
                labels[idx] += " "
            }

            if a.desc == "" {
                out += sprintf("\t%s %-11s %s\n", labels[idx], typelbl(a.value), note)
            } else {
                out += sprintf("\t%s %-11s - %s %s\n", labels[idx], typelbl(a.value), a.desc, note)
            }
        }
    }

    // Section 2: modes that are still selectable.
    if len(args.modes) > 0 {
        widestMode := 0
        for _, m in args.modes {
            if len(m.name) > widestMode {
                widestMode = len(m.name)
            }
        }

        out += "\nModes:\n"
        for _, m in args.modes {
            padded := m.name
            for len(padded) < widestMode {
                padded += " "
            }

            if m.desc == "" {
                out += sprintf("\t%s\n", padded)
            } else {
                out += sprintf("\t%s - %s\n", padded, m.desc)
            }
        }
    }

    // Section 3: positional layout; <x> is required, [x] is optional.
    list := args.listPositional
    if len(args.positional) > 0 || list != null {
        out += "\nFormat:\n\t"
        for _, p in args.positional {
            if p.required {
                out += sprintf("<%s> ", p.name)
            } else {
                out += sprintf("[%s] ", p.name)
            }
        }

        if list != null {
            if list.cutoff {
                out += ": "
            }

            if list.required {
                out += sprintf("<%s...>", list.name)
            } else {
                out += sprintf("[%s...]", list.name)
            }
        }

        out += "\n"
    }

    return out
}

//~~fn args.usage
// Prints the usage/help message and exits.
fn (args: ^Args) usage*() {
//~~
    printf("%s", args.getusage())

    // When help was requested through the flag, point out every other
    // argument that was also given, since none of them takes effect.
    if args.gethelp {
        for _, a in args.args {
            if a.occurences == 0 || a.name == "help" {
                continue
            }
            printf("\nWarning: Argument to %s is ignored", a.name)
        }
    }

    exit(0)
}

//~~fn args.stdargv
// Returns the command line arguments.
fn stdargv*(): []str {
//~~
    // Copy each argument reported by std::argc()/std::argv() into an array,
    // index 0 included.
    count := std::argc()
    argvCopy := make([]str, count)

    for idx in argvCopy {
        argvCopy[idx] = std::argv(idx)
    }

    return argvCopy
}

//~~fn args.check
// Parses the command line arguments and exits if there are any errors.
fn (args: ^Args) check*() {
//~~
    problems := args.parse()

    // A help request takes priority over any parse errors.
    if args.gethelp {
        args.usage()
        return
    }

    if len(problems) == 0 {
        return
    }

    // Show the usage message, a blank line, then one line per error.
    printf("%s\n", args.getusage())
    for _, problem in problems {
        printf("Error: %s\n", problem.msg)
    }
    exit(1)
}