Skip to content

Commit

Permalink
fix: try to get clsid from server list and reg
Browse files Browse the repository at this point in the history
  • Loading branch information
huskar-t committed Jun 14, 2024
1 parent 3c352b9 commit 5cb0e67
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 9 deletions.
67 changes: 58 additions & 9 deletions opcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"unsafe"

"github.com/huskar-t/opcda/com"
"golang.org/x/sys/windows/registry"

"golang.org/x/sys/windows"
)
Expand All @@ -32,15 +33,23 @@ func Connect(progID, node string) (opcServer *OPCServer, err error) {
if !com.IsLocal(node) {
location = com.CLSCTX_REMOTE_SERVER
}
iCatInfo, err := com.MakeCOMObjectEx(node, location, &com.CLSID_OpcServerList, &com.IID_IOPCServerList2)
if err != nil {
return nil, err
}
defer iCatInfo.Release()
sl := &com.IOPCServerList2{IUnknown: iCatInfo}
clsid, err := sl.CLSIDFromProgID(progID)
if err != nil {
return nil, err
var clsid *windows.GUID
if location == com.CLSCTX_LOCAL_SERVER {
id, err := windows.GUIDFromString(progID)
if err != nil {
return nil, err
}
clsid = &id
} else {
// try get clsid from server list
clsid, err = getClsIDFromServerList(progID, node, location)
if err != nil {
// try get clsid from windows reg
clsid, err = getClsIDFromReg(progID, node)
if err != nil {
return nil, err
}
}
}
iUnknownServer, err := com.MakeCOMObjectEx(node, location, clsid, &com.IID_IOPCServer)
if err != nil {
Expand Down Expand Up @@ -115,6 +124,46 @@ func Connect(progID, node string) (opcServer *OPCServer, err error) {
return opcServer, nil
}

func getClsIDFromServerList(progID, node string, location com.CLSCTX) (*windows.GUID, error) {
iCatInfo, err := com.MakeCOMObjectEx(node, location, &com.CLSID_OpcServerList, &com.IID_IOPCServerList2)
if err != nil {
return nil, err
}
defer iCatInfo.Release()
sl := &com.IOPCServerList2{IUnknown: iCatInfo}
clsid, err := sl.CLSIDFromProgID(progID)
if err != nil {
return nil, err
}
return clsid, nil
}

func getClsIDFromReg(progID, node string) (*windows.GUID, error) {
var clsid windows.GUID
var err error
hKey, err := registry.OpenRemoteKey(node, registry.CLASSES_ROOT)
if err != nil {
return nil, err
}
defer hKey.Close()
hProgIDKey, err := registry.OpenKey(hKey, progID, registry.READ)
if err != nil {
return nil, err
}
defer hProgIDKey.Close()
hClsidKey, err := registry.OpenKey(hProgIDKey, "CLSID", registry.READ)
if err != nil {
return nil, err
}
defer hClsidKey.Close()
clsidStr, _, err := hClsidKey.GetStringValue("")
if err != nil {
return nil, err
}
clsid, err = windows.GUIDFromString(clsidStr)
return &clsid, err
}

type ServerInfo struct {
ProgID string
ClsStr string
Expand Down
120 changes: 120 additions & 0 deletions opcserver_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package opcda

import (
"fmt"
"testing"
"time"

"github.com/huskar-t/opcda/com"
"github.com/stretchr/testify/assert"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/mgr"
)
Expand Down Expand Up @@ -356,3 +358,121 @@ func TestOPCServer_RegisterServerShutDown(t *testing.T) {
}
<-done
}

func Test_getClsIDFromReg(t *testing.T) {
id, err := windows.GUIDFromString(TestProgID)
if err != nil {
t.Fatal(err)
}
localNode, err := windows.ComputerName()
if err != nil {
t.Fatal(err)
}
type args struct {
progID string
node string
}

tests := []struct {
name string
args args
want *windows.GUID
wantErr assert.ErrorAssertionFunc
}{
{
name: "normal",
args: args{
progID: TestProgID,
node: localNode,
},
want: &id,
wantErr: assert.NoError,
},
{
name: "wrong node",
args: args{
progID: TestProgID,
node: "wrong",
},
want: nil,
wantErr: assert.Error,
},
{
name: "wrong progID",
args: args{
progID: "wrong",
node: localNode,
},
want: nil,
wantErr: assert.Error,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := getClsIDFromReg(tt.args.progID, tt.args.node)
if !tt.wantErr(t, err, fmt.Sprintf("getClsIDFromReg(%v, %v)", tt.args.progID, tt.args.node)) {
return
}
assert.Equalf(t, tt.want, got, "getClsIDFromReg(%v, %v)", tt.args.progID, tt.args.node)
})
}
}

func Test_getClsIDFromServerList(t *testing.T) {
id, err := windows.GUIDFromString(TestProgID)
if err != nil {
t.Fatal(err)
}
type args struct {
progID string
node string
location com.CLSCTX
}
tests := []struct {
name string
args args
want *windows.GUID
wantErr assert.ErrorAssertionFunc
}{
{
name: "Test with valid progID and node",
args: args{
progID: TestProgID,
node: TestHost,
location: com.CLSCTX_LOCAL_SERVER,
},
want: &id,
wantErr: assert.NoError,
},
{
name: "Test with invalid progID",
args: args{
progID: "InvalidProgID",
node: TestHost,
location: com.CLSCTX_LOCAL_SERVER,
},
want: nil,
wantErr: assert.Error,
},
{
name: "Test with invalid node",
args: args{
progID: TestProgID,
node: "InvalidNode",
location: com.CLSCTX_LOCAL_SERVER,
},
want: nil,
wantErr: assert.Error,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := getClsIDFromServerList(tt.args.progID, tt.args.node, tt.args.location)
if !tt.wantErr(t, err, fmt.Sprintf("getClsIDFromServerList(%v, %v, %v)", tt.args.progID, tt.args.node, tt.args.location)) {
return
}
assert.Equalf(t, tt.want, got, "getClsIDFromServerList(%v, %v, %v)", tt.args.progID, tt.args.node, tt.args.location)
})
}
}

0 comments on commit 5cb0e67

Please sign in to comment.