diff --git a/transaction.go b/transaction.go index b179627f..23d74bfb 100644 --- a/transaction.go +++ b/transaction.go @@ -24,6 +24,7 @@ import "C" import ( "errors" "fmt" + "runtime" "runtime/cgo" "strings" "sync/atomic" @@ -253,7 +254,7 @@ func Start(service, user string, handler ConversationHandler) (*Transaction, err // StartFunc registers the handler func as a conversation handler and starts // the transaction (see Start() documentation). func StartFunc(service, user string, handler func(Style, string) (string, error)) (*Transaction, error) { - return Start(service, user, ConversationFunc(handler)) + return start(service, user, ConversationFunc(handler), "") } // StartConfDir initiates a new PAM transaction. Service is treated identically to @@ -281,6 +282,12 @@ func StartConfDir(service, user string, handler ConversationHandler, confDir str return start(service, user, handler, confDir) } +type panicHandlerFunc func(any) + +var defaultPanicHandler = func(v any) { panic(v) } + +var panicHandler panicHandlerFunc = defaultPanicHandler + func start(service, user string, handler ConversationHandler, confDir string) (*Transaction, error) { switch handler.(type) { case BinaryConversationHandler: @@ -295,6 +302,33 @@ func start(service, user string, handler ConversationHandler, confDir string) (* conv: &C.struct_pam_conv{}, c: cgo.NewHandle(handler), } + + callers := make([]uintptr, 10) + haveCallerInfo := runtime.Callers(2, callers) > 0 + runtime.SetFinalizer(t, func(t *Transaction) { + handle := atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&t.handle))) + if handle == nil { + return + } + if haveCallerInfo { + stackTrace := make([]string, 0, len(callers)) + for i := 0; i < len(callers); i += 1 { + frame, more := runtime.CallersFrames(callers).Next() + stackTrace = append(stackTrace, fmt.Sprintf("%s:%d", + frame.File, frame.Line)) + if !more { + break + } + callers = callers[1:] + } + panicHandler(fmt.Sprintf("Transaction %p was never ended. "+ + "Initialization was at:\n %s", + t, strings.Join(stackTrace, "\n "))) + } else { + panicHandler(fmt.Sprintf("Transaction %p was never ended", t)) + } + }) + C.init_pam_conv(t.conv, C.uintptr_t(t.c)) s := C.CString(service) defer C.free(unsafe.Pointer(s)) diff --git a/transaction_test.go b/transaction_test.go index a9211ee1..622da86c 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -6,7 +6,10 @@ import ( "os" "os/user" "path/filepath" + "regexp" + "runtime" "testing" + "time" ) func maybeEndTransaction(t *testing.T, tx *Transaction) { @@ -550,6 +553,72 @@ func Test_Status(t *testing.T) { } } +func Test_Finalizer(t *testing.T) { + if !CheckPamHasStartConfdir() { + t.Skip("this requires PAM with Conf dir support") + } + + func() { + tx, err := StartConfDir("permit-service", "", nil, "test-services") + defer maybeEndTransaction(t, tx) + if err != nil { + t.Fatalf("start #error: %v", err) + } + }() + + runtime.GC() + // sleep to switch to finalizer goroutine + time.Sleep(5 * time.Millisecond) +} + +func Test_FinalizerNotCleanedUp(t *testing.T) { + if !CheckPamHasStartConfdir() { + t.Skip("this requires PAM with Conf dir support") + } + + panicChan := make(chan any) + panicHandler = func(v any) { panicChan <- v } + go func() { time.Sleep(time.Second * 2); panicChan <- errors.New("") }() + + func() { + _, err := StartConfDir("permit-service", "", nil, "test-services") + if err != nil { + t.Fatalf("start #error: %v", err) + } + }() + + runtime.GC() + + panicMsg := (<-panicChan).(string) + panicHandler = defaultPanicHandler + + match, err := regexp.MatchString( + "Transaction 0x[0-9a-f]+ was never ended", panicMsg) + if err != nil { + t.Fatalf("unexpected error %v", err) + } + if !match { + t.Fatalf("no match in result:\n%s", panicMsg) + } + match, err = regexp.MatchString( + "transaction.go:[0-9]+", panicMsg) + if err != nil { + t.Fatalf("unexpected error %v", err) + } + if !match { + t.Fatalf("no match in result:\n%s", panicMsg) + } + match, err = regexp.MatchString( + "transaction_test.go:[0-9]+", panicMsg) + if err != nil { + t.Fatalf("unexpected error %v", err) + } + if !match { + t.Fatalf("no match in result:\n%s", panicMsg) + } + fmt.Println("I'd be done...") +} + func TestFailure_001(t *testing.T) { tx := Transaction{} _, err := tx.GetEnvList()