pass a io.Writer to loggers and outputs so we can test the cmdline

This commit is contained in:
2024-01-01 20:53:05 +01:00
parent d1faa10a52
commit 8455c193eb
2 changed files with 19 additions and 15 deletions

View File

@@ -19,6 +19,7 @@ package main
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
@@ -87,7 +88,7 @@ func (c *Config) IncrImgs(num int) {
} }
// load commandline flags and config file // load commandline flags and config file
func InitConfig() (*Config, error) { func InitConfig(w io.Writer) (*Config, error) {
var k = koanf.New(".") var k = koanf.New(".")
// determine template based on os // determine template based on os
@@ -109,7 +110,7 @@ func InitConfig() (*Config, error) {
// setup custom usage // setup custom usage
f := flag.NewFlagSet("config", flag.ContinueOnError) f := flag.NewFlagSet("config", flag.ContinueOnError)
f.Usage = func() { f.Usage = func() {
fmt.Println(Usage) fmt.Fprintln(w, Usage)
os.Exit(0) os.Exit(0)
} }

21
main.go
View File

@@ -20,6 +20,7 @@ package main
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"log/slog" "log/slog"
"os" "os"
"runtime/debug" "runtime/debug"
@@ -30,10 +31,10 @@ import (
const LevelNotice = slog.Level(2) const LevelNotice = slog.Level(2)
func main() { func main() {
os.Exit(Main()) os.Exit(Main(os.Stdout))
} }
func Main() int { func Main(w io.Writer) int {
logLevel := &slog.LevelVar{} logLevel := &slog.LevelVar{}
opts := &tint.Options{ opts := &tint.Options{
Level: logLevel, Level: logLevel,
@@ -49,22 +50,22 @@ func Main() int {
} }
logLevel.Set(LevelNotice) logLevel.Set(LevelNotice)
var handler slog.Handler = tint.NewHandler(os.Stdout, opts) handler := tint.NewHandler(w, opts)
logger := slog.New(handler) logger := slog.New(handler)
slog.SetDefault(logger) slog.SetDefault(logger)
conf, err := InitConfig() conf, err := InitConfig(w)
if err != nil { if err != nil {
return Die(err) return Die(err)
} }
if conf.Showversion { if conf.Showversion {
fmt.Printf("This is kleingebaeck version %s\n", VERSION) fmt.Fprintf(w, "This is kleingebaeck version %s\n", VERSION)
return 0 return 0
} }
if conf.Showhelp { if conf.Showhelp {
fmt.Println(Usage) fmt.Fprintln(w, Usage)
return 0 return 0
} }
@@ -90,7 +91,7 @@ func Main() int {
} }
logLevel.Set(slog.LevelDebug) logLevel.Set(slog.LevelDebug)
var handler slog.Handler = tint.NewHandler(os.Stdout, opts) handler := tint.NewHandler(w, opts)
debuglogger := slog.New(handler).With( debuglogger := slog.New(handler).With(
slog.Group("program_info", slog.Group("program_info",
slog.Int("pid", os.Getpid()), slog.Int("pid", os.Getpid()),
@@ -100,6 +101,8 @@ func Main() int {
slog.SetDefault(debuglogger) slog.SetDefault(debuglogger)
} }
// defaultlogger := log.Default()
// defaultlogger.SetOutput(w)
slog.Debug("config", "conf", conf) slog.Debug("config", "conf", conf)
// prepare output dir // prepare output dir
@@ -131,10 +134,10 @@ func Main() int {
if conf.StatsCountAds == 1 { if conf.StatsCountAds == 1 {
adstr = "ad" adstr = "ad"
} }
fmt.Printf("Successfully downloaded %d %s with %d images to %s.\n", fmt.Fprintf(w, "Successfully downloaded %d %s with %d images to %s.\n",
conf.StatsCountAds, adstr, conf.StatsCountImages, conf.Outdir) conf.StatsCountAds, adstr, conf.StatsCountImages, conf.Outdir)
} else { } else {
fmt.Printf("No ads found.") fmt.Fprintf(w, "No ads found.")
} }
return 0 return 0