From f7e71e1d48962b062696a59d11802c981e7f1ba0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Fri, 13 Oct 2023 19:24:49 +0200 Subject: [PATCH] module-transaction: Do not allow parallel conversations by default Pam conversations per se may also run in parallel, but this implies that the application supports this. Since this normally not the case, do not create modules that may invoke the pam conversations in parallel by default, adding a mutex to protect such calls. --- cmd/pam-moduler/moduler.go | 12 ++++++-- .../integration-tester-module.go | 2 +- module-transaction.go | 29 +++++++++++++++++-- module-transaction_test.go | 13 +++++++-- 4 files changed, 47 insertions(+), 9 deletions(-) diff --git a/cmd/pam-moduler/moduler.go b/cmd/pam-moduler/moduler.go index f065ae10..a9ab9614 100644 --- a/cmd/pam-moduler/moduler.go +++ b/cmd/pam-moduler/moduler.go @@ -68,6 +68,7 @@ var ( moduleBuildFlags = flag.String("build-flags", "", "comma-separated list of go build flags to use when generating the module") moduleBuildTags = flag.String("build-tags", "", "comma-separated list of build tags to use when generating the module") noMain = flag.Bool("no-main", false, "whether to add an empty main to generated file") + parallelConv = flag.Bool("parallel-conv", false, "whether to support performing PAM conversations in parallel") ) // Usage is a replacement usage function for the flags package. @@ -136,6 +137,7 @@ func main() { generateTags: generateTags, noMain: *noMain, typeName: *typeName, + parallelConv: *parallelConv, } // Print the header and package clause. @@ -168,6 +170,7 @@ type Generator struct { generateTags []string buildFlags []string noMain bool + parallelConv bool } func (g *Generator) Printf(format string, args ...interface{}) { @@ -185,6 +188,11 @@ func (g *Generator) generate() { buildTagsArg = fmt.Sprintf("-tags %s", strings.Join(g.generateTags, ",")) } + var transactionCreator = "NewModuleTransactionInvoker" + if g.parallelConv { + transactionCreator = "NewModuleTransactionInvokerParallelConv" + } + vFuncs := map[string]string{ "authenticate": "Authenticate", "setcred": "SetCred", @@ -247,7 +255,7 @@ func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, return pam.Ignore } - mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh)) + mt := pam.%s(pam.NativeHandle(pamh)) ret, err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), sliceFromArgv(argc, argv)) @@ -257,7 +265,7 @@ func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, return ret } -`) +`, transactionCreator) for cName, goName := range vFuncs { g.Printf(` diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go index 23cc2624..70867f04 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go @@ -1,4 +1,4 @@ -//go:generate go run github.com/msteinert/pam/cmd/pam-moduler -type integrationTesterModule +//go:generate go run github.com/msteinert/pam/cmd/pam-moduler -type integrationTesterModule -parallel-conv //go:generate go generate --skip="pam_module.go" package main diff --git a/module-transaction.go b/module-transaction.go index d61cfe58..2c442eb7 100644 --- a/module-transaction.go +++ b/module-transaction.go @@ -20,6 +20,7 @@ import ( "fmt" "runtime" "runtime/cgo" + "sync" "sync/atomic" "unsafe" ) @@ -51,6 +52,7 @@ type ModuleHandlerFunc func(ModuleTransaction, Flags, []string) error // ModuleTransaction is the module-side handle for a PAM transaction type moduleTransaction struct { transactionBase + convMutex *sync.Mutex } // ModuleHandler is an interface for objects that can be used to create @@ -71,10 +73,27 @@ type ModuleTransactionInvoker interface { InvokeHandler(handler ModuleHandlerFunc, flags Flags, args []string) (Status, error) } -// NewModuleTransactionInvoker allows initializing a transaction invoker from -// the module side. +// NewModuleTransaction allows initializing a transaction from the module side. +// Conversations using this transaction can be multi-thread, but this requires +// the application loading the module to support this, otherwise we may just +// break their assumptions. +func NewModuleTransactionParallelConv(handle NativeHandle) ModuleTransaction { + return &moduleTransaction{transactionBase{handle: handle}, nil} +} + +// NewModuleTransactionInvoker allows initializing a transaction invoker from the +// module side. func NewModuleTransactionInvoker(handle NativeHandle) ModuleTransactionInvoker { - return &moduleTransaction{transactionBase{handle: handle}} + return &moduleTransaction{transactionBase{handle: handle}, &sync.Mutex{}} +} + +// NewModuleTransactionInvokerParallelConv allows initializing a transaction invoker +// from the module side. +// Conversations using this transaction can be multi-thread, but this requires +// the application loading the module to support this, otherwise we may just +// break their assumptions. +func NewModuleTransactionInvokerParallelConv(handle NativeHandle) ModuleTransactionInvoker { + return &moduleTransaction{transactionBase{handle: handle}, nil} } func (m *moduleTransaction) InvokeHandler(handler ModuleHandlerFunc, @@ -467,6 +486,10 @@ func (m *moduleTransaction) startConvMultiImpl(iface moduleTransactionIface, } } + if m.convMutex != nil { + m.convMutex.Lock() + defer m.convMutex.Unlock() + } var cResponses *C.struct_pam_response if err := m.handlePamStatus( iface.startConv(conv, C.int(len(requests)), cMessages, &cResponses)); err != nil { diff --git a/module-transaction_test.go b/module-transaction_test.go index 627c915a..61568ba8 100644 --- a/module-transaction_test.go +++ b/module-transaction_test.go @@ -296,11 +296,9 @@ func Test_ModuleTransaction_InvokeHandler(t *testing.T) { } } -func Test_MockModuleTransaction(t *testing.T) { +func testMockModuleTransaction(t *testing.T, mt *moduleTransaction) { t.Parallel() - mt := NewModuleTransactionInvoker(nil).(*moduleTransaction) - tests := map[string]struct { testFunc func(mock *mockModuleTransaction) (any, error) mockExpectations mockModuleTransactionExpectations @@ -857,3 +855,12 @@ func Test_MockModuleTransaction(t *testing.T) { }) } } + +func Test_MockModuleTransaction(t *testing.T) { + testMockModuleTransaction(t, NewModuleTransactionInvoker(nil).(*moduleTransaction)) +} + +func Test_MockModuleTransactionParallelConv(t *testing.T) { + testMockModuleTransaction(t, + NewModuleTransactionInvokerParallelConv(nil).(*moduleTransaction)) +}