Skip to content

Commit 17dea2c

Browse files
authored
Simplify connection handling (#68)
1 parent 6020765 commit 17dea2c

File tree

1 file changed

+30
-45
lines changed

1 file changed

+30
-45
lines changed

rscp/client.go

Lines changed: 30 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@ package rscp
22

33
import (
44
"crypto/cipher"
5+
"errors"
56
"fmt"
67
"io"
78
"math"
89
"net"
910
"time"
1011

11-
"errors"
12-
1312
"github.com/azihsoyn/rijndael256"
1413
"github.com/sirupsen/logrus"
1514
)
@@ -20,7 +19,6 @@ var Log = logrus.New()
2019
type Client struct {
2120
config ClientConfig
2221
connectionString string
23-
isConnected bool
2422
isAuthenticated bool
2523
conn net.Conn
2624
encrypter cipher.BlockMode
@@ -34,15 +32,15 @@ func NewClient(config ClientConfig) (*Client, error) {
3432
if err := config.check(); err != nil {
3533
return nil, err
3634
}
35+
3736
key := createAESKey(config.Key)
3837
initIV := newIV()
3938
cipherBlock, _ := rijndael256.NewCipher(key[:]) // implementation does not return an error
39+
4040
// Intitialize the Client structure.
4141
c := &Client{
4242
config: config,
4343
connectionString: fmt.Sprintf("%s:%d", config.Address, config.Port),
44-
isConnected: false,
45-
isAuthenticated: false,
4644
encrypter: cipher.NewCBCEncrypter(cipherBlock, initIV[:]),
4745
decrypter: cipher.NewCBCDecrypter(cipherBlock, initIV[:]),
4846
}
@@ -54,11 +52,8 @@ func (c *Client) send(messages []Message) error {
5452
if err := validateRequests(messages); err != nil {
5553
return err
5654
}
57-
var (
58-
msg []byte
59-
err error
60-
)
61-
if msg, err = Write(&c.encrypter, messages, c.config.UseChecksum.(bool)); err != nil {
55+
msg, err := Write(&c.encrypter, messages, c.config.UseChecksum.(bool))
56+
if err != nil {
6257
return err
6358
}
6459
if err := c.conn.SetWriteDeadline(time.Now().Add(c.config.SendTimeout)); err != nil {
@@ -84,11 +79,13 @@ func (c *Client) receive() ([]Message, error) {
8479

8580
for i, data := 0, make([]byte, uint32(RSCP_CRYPT_BLOCK_SIZE)*uint32(c.config.ReceiveBufferBlockSize)); ; {
8681
var err error
82+
8783
if i, err = c.conn.Read(data); err != nil {
8884
return nil, fmt.Errorf("error during receive response: %w", err)
8985
} else if i == 0 {
9086
return nil, ErrRscpInvalidFrameLength
9187
}
88+
9289
switch m, err = Read(&c.decrypter, &buf, &crcFlag, &frameSize, &dataSize, data[:i]); {
9390
case errors.Is(err, ErrRscpInvalidFrameLength):
9491
// frame not complete
@@ -105,17 +102,15 @@ func (c *Client) receive() ([]Message, error) {
105102
// connect create connection
106103
func (c *Client) connect() error {
107104
Log.Infof("Connecting to %s", c.connectionString)
108-
var (
109-
conn net.Conn
110-
err error
111-
)
112-
if conn, err = net.DialTimeout("tcp", c.connectionString, c.config.ConnectionTimeout); err != nil {
113-
c.isConnected = false
105+
106+
conn, err := net.DialTimeout("tcp", c.connectionString, c.config.ConnectionTimeout)
107+
if err != nil {
114108
return err
115109
}
110+
111+
Log.Infof("successfully connected to %s", conn.RemoteAddr())
116112
c.conn = conn
117-
c.isConnected = true
118-
Log.Infof("successfully connected to %s", c.conn.RemoteAddr())
113+
119114
return nil
120115
}
121116

@@ -141,20 +136,20 @@ func (c *Client) authenticate() error {
141136
if orgLogLevel < RequiredAuthLogLevel {
142137
Log.SetLevel(orgLogLevel)
143138
}
144-
var (
145-
messages []Message
146-
err error
147-
)
148-
if messages, err = c.receive(); err != nil {
139+
140+
messages, err := c.receive()
141+
if err != nil {
149142
if errors.Is(err, io.EOF) {
150143
Log.Warnf("Hint: EOF during authentification usually is due a wrong rscp key")
151144
}
152145
return fmt.Errorf("authentication error: %w", err)
153146
}
147+
154148
if messages[0].Tag != RSCP_AUTHENTICATION {
155149
c.isAuthenticated = false
156150
return fmt.Errorf("authentication failed: %+v", messages[0])
157151
}
152+
158153
switch v := messages[0].Value.(type) {
159154
default:
160155
c.isAuthenticated = false
@@ -177,38 +172,35 @@ func (c *Client) authenticate() error {
177172
}
178173

179174
// Disconnect the client
180-
func (c *Client) Disconnect() error {
175+
func (c *Client) Disconnect() (err error) {
181176
c.isAuthenticated = false
182-
c.isConnected = false
183177

184-
if c.isConnected && c.conn != nil {
185-
if err := c.conn.Close(); err != nil {
186-
return err
187-
}
178+
if c.conn != nil {
179+
err = c.conn.Close()
180+
c.conn = nil
181+
Log.Info("disconnected")
188182
}
189-
Log.Info("disconnected")
190-
return nil
183+
184+
return err
191185
}
192186

193187
// Send a message and return the response.
194188
//
195189
// connects and authenticates the first time used.
196190
func (c *Client) Send(request Message) (*Message, error) {
197-
var (
198-
responses []Message
199-
err error
200-
)
201-
if responses, err = (c.SendMultiple([]Message{request})); err != nil {
191+
responses, err := c.SendMultiple([]Message{request})
192+
if err != nil {
202193
return nil, err
203194
}
195+
204196
return &responses[0], nil
205197
}
206198

207199
// Send multiple messages in one round-trip and return the response.
208200
//
209201
// connects and authenticates the first time used.
210202
func (c *Client) SendMultiple(requests []Message) ([]Message, error) {
211-
if !c.isConnected {
203+
if c.conn == nil {
212204
if err := c.connect(); err != nil {
213205
return nil, err
214206
}
@@ -221,12 +213,5 @@ func (c *Client) SendMultiple(requests []Message) ([]Message, error) {
221213
if err := c.send(requests); err != nil {
222214
return nil, err
223215
}
224-
var (
225-
responses []Message
226-
err error
227-
)
228-
if responses, err = c.receive(); err != nil {
229-
return nil, err
230-
}
231-
return responses, nil
216+
return c.receive()
232217
}

0 commit comments

Comments
 (0)