Skip to content

Commit

Permalink
feat: increase testability
Browse files Browse the repository at this point in the history
Getting rid of global state is a nice step forward in improving
testability.
Caller code (and more importantly tests) can now fully control
the registered fields or the logger used to produce messages.

All while not vastly changing the current interface, which keeps the
global logging functions it used to have. Those obviously still rely on
shared global state.
  • Loading branch information
Crocmagnon authored and fsamin committed Dec 5, 2023
1 parent 84d5c3a commit 01b6af8
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 94 deletions.
28 changes: 20 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,44 @@ It offers a convenient way to keep your logs when you are running unit tests.
```

## How to use
### Global logger

By default, it is initialized to wrap `logrus` package. You can override `log.Factory` to use the logger library you want.

First register `fields`

```golang
const myField = log.Field("component")
const myField = log.Field("component")

func init() {
log.RegisterField(myField)
}
func init() {
log.RegisterField(myField)
}
```

Then add `fields` as values to you current context.

```golang
ctx = context.WithValue(ctx, myField, "myComponent")
ctx = context.WithValue(ctx, myField, "myComponent")
```

Finally log as usual.
```golang
log.Info(ctx, "this is a log")
log.Info(ctx, "this is a log")
```

### Logger instance
You can opt to use a logger instance instead of the global state:
```golang
logger := log.NewWithFactory(log.NewLogrusWrapper(logrus.New()))

// Registration is scoped to the logger instance
logger.RegisterField(myField)

logger.Info(ctx, "this is a log")
```

A typical use case may be to instanciate a logger at app startup and storing it in a struct for use in other methods.

## Examples

```golang
Expand Down Expand Up @@ -74,5 +88,3 @@ Log errors easily.
log.ErrorWithStackTrace(ctx, err) // will produce a nice stack_trace field
)
```
217 changes: 148 additions & 69 deletions log.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,119 +5,146 @@ import (
"fmt"
"runtime"
"sort"
"sync"

"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)

func RegisterField(fields ...Field) {
registeredFieldsMutex.Lock()
defer registeredFieldsMutex.Unlock()
const (
FieldSourceFile = Field("source_file")
FieldSourceLine = Field("source_line")
FieldCaller = Field("caller")
FieldStackTrace = Field("stack_trace")
)

var global *Logger
var Factory WrapperFactoryFunc

func init() {
global = New()
global.callerFrameToSkip = 3
global.globalFactory = true
}

type Logger struct {
registeredFields []Field
registeredFieldsMutex sync.Mutex
excludeRules map[Field]any
excludeRulesMutex sync.Mutex
factory WrapperFactoryFunc
callerFrameToSkip int

// if globalFactory is true, Logger.factory is ignored and Factory is used instead.
globalFactory bool
}

func New() *Logger {
return NewWithFactory(NewLogrusWrapper(logrus.StandardLogger()))
}

func NewWithFactory(factory WrapperFactoryFunc) *Logger {
logger := &Logger{factory: factory, callerFrameToSkip: 2, excludeRules: make(map[Field]any)}
logger.RegisterDefaultFields()
return logger
}

func (l *Logger) RegisterField(fields ...Field) {
l.registeredFieldsMutex.Lock()
defer l.registeredFieldsMutex.Unlock()

for _, f := range fields {
var exist bool
for _, existingF := range registeredFields {
for _, existingF := range l.registeredFields {
if f == existingF {
exist = true
break
}
}
if !exist {
registeredFields = append(registeredFields, f)
l.registeredFields = append(l.registeredFields, f)
}
}

sort.Slice(registeredFields, func(i, j int) bool {
return registeredFields[i] < registeredFields[j]
sort.Slice(l.registeredFields, func(i, j int) bool {
return l.registeredFields[i] < l.registeredFields[j]
})
}

func UnregisterField(fields ...Field) {
registeredFieldsMutex.Lock()
defer registeredFieldsMutex.Unlock()
func (l *Logger) UnregisterField(fields ...Field) {
l.registeredFieldsMutex.Lock()
defer l.registeredFieldsMutex.Unlock()

loop:
for _, f := range fields {
for i, existingF := range registeredFields {
for i, existingF := range l.registeredFields {
if f == existingF {
registeredFields = append(registeredFields[:i], registeredFields[i+1:]...)
l.registeredFields = append(l.registeredFields[:i], l.registeredFields[i+1:]...)
goto loop
}
}
}

sort.Slice(registeredFields, func(i, j int) bool {
return registeredFields[i] < registeredFields[j]
sort.Slice(l.registeredFields, func(i, j int) bool {
return l.registeredFields[i] < l.registeredFields[j]
})
}

// GetRegisteredFields returns a copy of the registered fields.
func GetRegisteredFields() []Field {
fields := make([]Field, len(registeredFields))
copy(fields, registeredFields)
func (l *Logger) GetRegisteredFields() []Field {
fields := make([]Field, len(l.registeredFields))
copy(fields, l.registeredFields)
return fields
}

// CallerFrameToSkip correspond to the number of frame to skip while retrieving the caller stack
var CallerFrameToSkip = 2

func Skip(field Field, value interface{}) {
excludeRulesMutex.Lock()
defer excludeRulesMutex.Unlock()

if excludeRules == nil {
excludeRules = make(map[Field]any)
}

excludeRules[field] = value
func (l *Logger) RegisterDefaultFields() {
l.RegisterField(FieldSourceFile, FieldSourceLine, FieldCaller, FieldStackTrace)
}

func Debug(ctx context.Context, format string, args ...interface{}) {
call(ctx, LevelDebug, format, args...)
}
func (l *Logger) Skip(field Field, value interface{}) {
l.excludeRulesMutex.Lock()
defer l.excludeRulesMutex.Unlock()

func Info(ctx context.Context, format string, args ...interface{}) {
call(ctx, LevelInfo, format, args...)
l.excludeRules[field] = value
}

func Warn(ctx context.Context, format string, args ...interface{}) {
call(ctx, LevelWarn, format, args...)
func (l *Logger) Debug(ctx context.Context, format string, args ...interface{}) {
l.call(ctx, LevelDebug, format, args...)
}

func Error(ctx context.Context, format string, args ...interface{}) {
call(ctx, LevelError, format, args...)
func (l *Logger) Info(ctx context.Context, format string, args ...interface{}) {
l.call(ctx, LevelInfo, format, args...)
}

func Fatal(ctx context.Context, format string, args ...interface{}) {
call(ctx, LevelFatal, format, args...)
func (l *Logger) Warn(ctx context.Context, format string, args ...interface{}) {
l.call(ctx, LevelWarn, format, args...)
}

func Panic(ctx context.Context, format string, args ...interface{}) {
call(ctx, LevelPanic, format, args...)
func (l *Logger) Error(ctx context.Context, format string, args ...interface{}) {
l.call(ctx, LevelError, format, args...)
}

var (
FieldSourceFile = Field("source_file")
FieldSourceLine = Field("source_line")
FieldCaller = Field("caller")
FieldStackTrace = Field("stack_trace")
)

func init() {
RegisterDefaultFields()
func (l *Logger) Fatal(ctx context.Context, format string, args ...interface{}) {
l.call(ctx, LevelFatal, format, args...)
}

func RegisterDefaultFields() {
RegisterField(FieldSourceFile, FieldSourceLine, FieldCaller, FieldStackTrace)
func (l *Logger) Panic(ctx context.Context, format string, args ...interface{}) {
l.call(ctx, LevelPanic, format, args...)
}

func call(ctx context.Context, level Level, format string, args ...interface{}) {
entry := Factory()
func (l *Logger) call(ctx context.Context, level Level, format string, args ...interface{}) {
var entry Wrapper

if l.globalFactory {
entry = Factory()
} else {
entry = l.factory()
}

if level < entry.GetLevel() {
return
}

pc, file, line, ok := runtime.Caller(CallerFrameToSkip)
pc, file, line, ok := runtime.Caller(l.callerFrameToSkip)
if ok {
ctx = context.WithValue(ctx, FieldSourceFile, file)
ctx = context.WithValue(ctx, FieldSourceLine, line)
Expand All @@ -127,10 +154,10 @@ func call(ctx context.Context, level Level, format string, args ...interface{})
}
}

for _, k := range registeredFields {
for _, k := range l.registeredFields {
v := ctx.Value(k)
if v != nil {
if exludeValue, has := excludeRules[k]; has {
if exludeValue, has := l.excludeRules[k]; has {
if v == exludeValue {
return
}
Expand Down Expand Up @@ -160,6 +187,22 @@ func call(ctx context.Context, level Level, format string, args ...interface{})
}
}

func (l *Logger) ErrorWithStackTrace(ctx context.Context, err error) {
ctx = ContextWithStackTrace(ctx, err)
l.call(ctx, LevelError, err.Error())
}

func (l *Logger) FieldValues(ctx context.Context) map[Field]interface{} {
res := make(map[Field]interface{}, 10)
for _, k := range l.registeredFields {
v := ctx.Value(k)
if v != nil {
res[k] = v
}
}
return res
}

type StackTracer interface {
StackTrace() errors.StackTrace
}
Expand All @@ -172,18 +215,54 @@ func ContextWithStackTrace(ctx context.Context, err error) context.Context {
return ctx
}

func RegisterField(fields ...Field) {
global.RegisterField(fields...)
}

func UnregisterField(fields ...Field) {
global.UnregisterField(fields...)
}

func GetRegisteredFields() []Field {
return global.GetRegisteredFields()
}

func RegisterDefaultFields() {
global.RegisterDefaultFields()
}

func Skip(field Field, value interface{}) {
global.Skip(field, value)
}

func Debug(ctx context.Context, format string, args ...interface{}) {
global.Debug(ctx, format, args...)
}

func Info(ctx context.Context, format string, args ...interface{}) {
global.Info(ctx, format, args...)
}

func Warn(ctx context.Context, format string, args ...interface{}) {
global.Warn(ctx, format, args...)
}

func Error(ctx context.Context, format string, args ...interface{}) {
global.Error(ctx, format, args...)
}

func Fatal(ctx context.Context, format string, args ...interface{}) {
global.Fatal(ctx, format, args...)
}

func Panic(ctx context.Context, format string, args ...interface{}) {
global.Panic(ctx, format, args...)
}

func ErrorWithStackTrace(ctx context.Context, err error) {
ctx = ContextWithStackTrace(ctx, err)
call(ctx, LevelError, err.Error())
global.ErrorWithStackTrace(ctx, err)
}

func FieldValues(ctx context.Context) map[Field]interface{} {
res := make(map[Field]interface{}, 10)
for _, k := range registeredFields {
v := ctx.Value(k)
if v != nil {
res[k] = v
}
}
return res
return global.FieldValues(ctx)
}
Loading

0 comments on commit 01b6af8

Please sign in to comment.