Skip to content

Commit

Permalink
Merge pull request #9 from huskar-t/fix/connect
Browse files Browse the repository at this point in the history
fix: try to get clsid from server list and reg
  • Loading branch information
huskar-t authored Jun 20, 2024
2 parents 3c352b9 + 8add191 commit 89be54b
Show file tree
Hide file tree
Showing 8 changed files with 354 additions and 32 deletions.
2 changes: 1 addition & 1 deletion com/IOPCServerList.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (sl *IOPCServerList2) Vtbl() *IOPCServerListVtbl {
return (*IOPCServerListVtbl)(unsafe.Pointer(sl.IUnknown.LpVtbl))
}

func (sl *IOPCServerList2) EnumClassesOfCateGories(rgcatidImpl []windows.GUID, rgcatidReq []windows.GUID) (ppenumClsid *IEnumGUID, err error) {
func (sl *IOPCServerList2) EnumClassesOfCategories(rgcatidImpl []windows.GUID, rgcatidReq []windows.GUID) (ppenumClsid *IEnumGUID, err error) {
var r0 uintptr
cImplemented := uint32(len(rgcatidImpl))
cRequired := uint32(len(rgcatidReq))
Expand Down
22 changes: 16 additions & 6 deletions com/com.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com

import (
"errors"
"strings"
"syscall"
"unsafe"

Expand Down Expand Up @@ -184,16 +185,18 @@ func SysFreeString(v *uint16) (err error) {
}

func MakeCOMObjectEx(hostname string, serverLocation CLSCTX, requestedClass *windows.GUID, requestedInterface *windows.GUID) (*IUnknown, error) {
requestedServerInfo := COSERVERINFO{
PwszName: windows.StringToUTF16Ptr(hostname),
PAuthInfo: nil,
}
reqInterface := MULTI_QI{
PIID: requestedInterface,
PItf: nil,
Hr: 0,
}
err := CoCreateInstanceEx(requestedClass, nil, serverLocation, &requestedServerInfo, 1, &reqInterface)
var serverInfoPtr *COSERVERINFO = nil
if serverLocation != CLSCTX_LOCAL_SERVER {
serverInfoPtr = &COSERVERINFO{
PwszName: windows.StringToUTF16Ptr(hostname),
}
}
err := CoCreateInstanceEx(requestedClass, nil, serverLocation, serverInfoPtr, 1, &reqInterface)
if err != nil {
return nil, err
}
Expand All @@ -204,7 +207,14 @@ func MakeCOMObjectEx(hostname string, serverLocation CLSCTX, requestedClass *win
}

func IsLocal(host string) bool {
return host == "" || host == "localhost" || host == "127.0.0.1"
if host == "" || host == "localhost" || host == "127.0.0.1" {
return true
}
name, err := windows.ComputerName()
if err != nil {
return false
}
return strings.ToLower(name) == strings.ToLower(host)
}

// Initialize initialize COM with COINIT_MULTITHREADED
Expand Down
2 changes: 1 addition & 1 deletion opcbrowser.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func NewOPCBrowser(parent *OPCServer) (*OPCBrowser, error) {
var iBrowseServerAddressSpace *com.IUnknown
err := parent.iServer.QueryInterface(&com.IID_IOPCBrowseServerAddressSpace, unsafe.Pointer(&iBrowseServerAddressSpace))
if err != nil {
return nil, err
return nil, NewOPCWrapperError("query interface IOPCBrowseServerAddressSpace", err)
}
return &OPCBrowser{
iBrowseServerAddressSpace: &com.IOPCBrowseServerAddressSpace{IUnknown: iBrowseServerAddressSpace},
Expand Down
24 changes: 23 additions & 1 deletion opcerror.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package opcda

import "fmt"
import (
"fmt"
)

type OPCError struct {
ErrorCode int32
Expand Down Expand Up @@ -54,3 +56,23 @@ var (
OPCNotFound = uint32(0xC0040011)
OPCInvalidPID = uint32(0xC0040203)
)

type OPCWrapperError struct {
Err error
Info string
}

func (e *OPCWrapperError) Error() string {
if e.Err == nil {
return fmt.Sprintf("%s: <nil>", e.Info)
}
return fmt.Sprintf("%s: %s", e.Info, e.Err.Error())
}

func NewOPCWrapperError(info string, err error) error {
return &OPCWrapperError{Err: err, Info: info}
}

func (e *OPCWrapperError) Unwrap() error {
return e.Err
}
121 changes: 121 additions & 0 deletions opcerror_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package opcda

import (
"fmt"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -51,3 +53,122 @@ func TestOPCError_Error(t *testing.T) {
})
}
}

func TestOPCWrapperError_Error(t *testing.T) {
type fields struct {
Err error
Info string
}
tests := []struct {
name string
fields fields
want string
}{
{
name: "Test with error and info",
fields: fields{
Err: fmt.Errorf("test error"),
Info: "test info",
},
want: "test info: test error",
},
{
name: "Test with error and no info",
fields: fields{
Err: fmt.Errorf("test error"),
Info: "",
},
want: ": test error",
},
{
name: "Test with no error and info",
fields: fields{
Err: nil,
Info: "test info",
},
want: "test info: <nil>",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &OPCWrapperError{
Err: tt.fields.Err,
Info: tt.fields.Info,
}
if got := e.Error(); got != tt.want {
t.Errorf("OPCWrapperError.Error() = %v, want %v", got, tt.want)
}
})
}
}

func TestNewOPCWrapperError(t *testing.T) {
type args struct {
info string
err error
}
tests := []struct {
name string
args args
want *OPCWrapperError
}{
{
name: "Test with error and info",
args: args{
err: fmt.Errorf("test error"),
info: "test info",
},
want: &OPCWrapperError{
Err: fmt.Errorf("test error"),
Info: "test info",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := NewOPCWrapperError(tt.args.info, tt.args.err); !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewOPCWrapperError() = %v, want %v", got, tt.want)
}
})
}
}

func TestOPCWrapperError_Unwrap(t *testing.T) {
type fields struct {
Err error
Info string
}
tests := []struct {
name string
fields fields
want error
}{
{
name: "Test with error",
fields: fields{
Err: fmt.Errorf("test error"),
Info: "test info",
},
want: fmt.Errorf("test error"),
},
{
name: "Test with no error",
fields: fields{
Err: nil,
Info: "test info",
},
want: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &OPCWrapperError{
Err: tt.fields.Err,
Info: tt.fields.Info,
}
if got := e.Unwrap(); !reflect.DeepEqual(got, tt.want) {
t.Errorf("OPCWrapperError.Unwrap() = %v, want %v", got, tt.want)
}
})
}
}
8 changes: 4 additions & 4 deletions opcgroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,20 @@ func NewOPCGroup(
var iUnknownSyncIO *com.IUnknown
err := iUnknown.QueryInterface(&com.IID_IOPCSyncIO, unsafe.Pointer(&iUnknownSyncIO))
if err != nil {
return nil, err
return nil, NewOPCWrapperError("query interface IOPCSyncIO", err)
}
var iUnknownAsyncIO2 *com.IUnknown
err = iUnknown.QueryInterface(&com.IID_IOPCAsyncIO2, unsafe.Pointer(&iUnknownAsyncIO2))
if err != nil {
iUnknownSyncIO.Release()
return nil, err
return nil, NewOPCWrapperError("query interface IOPCAsyncIO2", err)
}
var iUnknownItemMgt *com.IUnknown
err = iUnknown.QueryInterface(&com.IID_IOPCItemMgt, unsafe.Pointer(&iUnknownItemMgt))
if err != nil {
iUnknownSyncIO.Release()
iUnknownAsyncIO2.Release()
return nil, err
return nil, NewOPCWrapperError("query interface IOPCItemMgt", err)
}

o := &OPCGroup{
Expand Down Expand Up @@ -343,7 +343,7 @@ func (g *OPCGroup) advice() (err error) {
var iUnknownContainer *com.IUnknown
err = g.groupStateMgt.QueryInterface(&com.IID_IConnectionPointContainer, unsafe.Pointer(&iUnknownContainer))
if err != nil {
return
return NewOPCWrapperError("query interface IConnectionPointContainer", err)
}
defer func() {
if err != nil {
Expand Down
Loading

0 comments on commit 89be54b

Please sign in to comment.