diff --git a/transport_tcp.go b/transport_tcp.go index 9e9a6c7..a0d93f4 100644 --- a/transport_tcp.go +++ b/transport_tcp.go @@ -79,7 +79,7 @@ func (t *TCPTransport) Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg return nil, err } defer conn.Close() - err = writeMessage(conn, message) + err = writeMessage(conn, 0, message) if err != nil { return nil, err } @@ -110,12 +110,13 @@ func readMessage(reader io.Reader) (*dns.Msg, error) { return &message, err } -func writeMessage(writer io.Writer, message *dns.Msg) error { +func writeMessage(writer io.Writer, messageId uint16, message *dns.Msg) error { requestLen := message.Len() buffer := buf.NewSize(3 + requestLen) defer buffer.Release() common.Must(binary.Write(buffer, binary.BigEndian, uint16(requestLen))) exMessage := *message + exMessage.Id = messageId exMessage.Compress = true rawMessage, err := exMessage.PackBuffer(buffer.FreeBytes()) if err != nil { diff --git a/transport_tls.go b/transport_tls.go index 93eeead..41824aa 100644 --- a/transport_tls.go +++ b/transport_tls.go @@ -115,10 +115,8 @@ func (t *TLSTransport) Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg } func (t *TLSTransport) exchange(message *dns.Msg, conn *tlsDNSConn) (*dns.Msg, error) { - messageId := message.Id conn.queryId++ - message.Id = conn.queryId - err := writeMessage(conn, message) + err := writeMessage(conn, conn.queryId, message) if err != nil { conn.Close() return nil, E.Cause(err, "write request") @@ -128,7 +126,6 @@ func (t *TLSTransport) exchange(message *dns.Msg, conn *tlsDNSConn) (*dns.Msg, e conn.Close() return nil, E.Cause(err, "read response") } - response.Id = messageId t.access.Lock() t.connections.PushBack(conn) t.access.Unlock()