Skip to content

Commit

Permalink
Change callback signatures to pass down context where applicable (#247)
Browse files Browse the repository at this point in the history
Changes a few server callback interface methods to pass down a context to propagate request information effectively. This closes #214
  • Loading branch information
jaronoff97 authored Jan 23, 2024
1 parent a669c09 commit 4d07a6a
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 72 deletions.
29 changes: 28 additions & 1 deletion client/clientimpl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -627,12 +627,24 @@ func TestAgentIdentification(t *testing.T) {
ulid.Timestamp(time.Now()), ulid.Monotonic(rand.New(rand.NewSource(0)), 0),
)
var rcvAgentInstanceUid atomic.Value
var sentInvalidId atomic.Bool
srv.OnMessage = func(msg *protobufs.AgentToServer) *protobufs.ServerToAgent {
rcvAgentInstanceUid.Store(msg.InstanceUid)
if sentInvalidId.Load() {
return &protobufs.ServerToAgent{
InstanceUid: msg.InstanceUid,
AgentIdentification: &protobufs.AgentIdentification{
// If we sent the invalid one first, send a valid one now
NewInstanceUid: newInstanceUid.String(),
},
}
}
sentInvalidId.Store(true)
return &protobufs.ServerToAgent{
InstanceUid: msg.InstanceUid,
AgentIdentification: &protobufs.AgentIdentification{
NewInstanceUid: newInstanceUid.String(),
// Start by sending an invalid id forcing an error.
NewInstanceUid: "",
},
}
}
Expand Down Expand Up @@ -660,6 +672,21 @@ func TestAgentIdentification(t *testing.T) {
// Send a dummy message
_ = client.SetAgentDescription(createAgentDescr())

// Verify that the old instance id was not overridden
eventually(
t,
func() bool {
instanceUid, ok := rcvAgentInstanceUid.Load().(string)
if !ok {
return false
}
return instanceUid == oldInstanceUid
},
)

// Send a dummy message again to get the _new_ id
_ = client.SetAgentDescription(createAgentDescr())

// When it was sent, the new instance uid should have been used, which should
// have been observed by the Server
eventually(
Expand Down
18 changes: 10 additions & 8 deletions client/internal/receivedprocessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,10 @@ func (r *receivedProcessor) ProcessReceivedMessage(ctx context.Context, msg *pro
}

if msg.AgentIdentification != nil {
err := r.rcvAgentIdentification(msg.AgentIdentification)
if err == nil {
err := r.rcvAgentIdentification(ctx, msg.AgentIdentification)
if err != nil {
r.logger.Errorf(ctx, "Failed to set agent ID: %v", err)
} else {
msgData.AgentIdentification = msg.AgentIdentification
}
}
Expand All @@ -146,7 +148,7 @@ func (r *receivedProcessor) ProcessReceivedMessage(ctx context.Context, msg *pro

err := msg.GetErrorResponse()
if err != nil {
r.processErrorResponse(err)
r.processErrorResponse(ctx, err)
}
}

Expand Down Expand Up @@ -203,21 +205,21 @@ func (r *receivedProcessor) rcvOpampConnectionSettings(ctx context.Context, sett
}
}

func (r *receivedProcessor) processErrorResponse(body *protobufs.ServerErrorResponse) {
func (r *receivedProcessor) processErrorResponse(ctx context.Context, body *protobufs.ServerErrorResponse) {
// TODO: implement this.
r.logger.Errorf(context.Background(), "received an error from server: %s", body.ErrorMessage)
r.logger.Errorf(ctx, "received an error from server: %s", body.ErrorMessage)
}

func (r *receivedProcessor) rcvAgentIdentification(agentId *protobufs.AgentIdentification) error {
func (r *receivedProcessor) rcvAgentIdentification(ctx context.Context, agentId *protobufs.AgentIdentification) error {
if agentId.NewInstanceUid == "" {
err := errors.New("empty instance uid is not allowed")
r.logger.Debugf(context.Background(), err.Error())
r.logger.Debugf(ctx, err.Error())
return err
}

err := r.sender.SetInstanceUid(agentId.NewInstanceUid)
if err != nil {
r.logger.Errorf(context.Background(), "Error while setting instance uid: %v", err)
r.logger.Errorf(ctx, "Error while setting instance uid: %v", err)
return err
}

Expand Down
16 changes: 8 additions & 8 deletions internal/examples/agent/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ func (agent *Agent) connect() error {
return nil
}

func (agent *Agent) disconnect() {
agent.logger.Debugf(context.Background(), "Disconnecting from server...")
agent.opampClient.Stop(context.Background())
func (agent *Agent) disconnect(ctx context.Context) {
agent.logger.Debugf(ctx, "Disconnecting from server...")
agent.opampClient.Stop(ctx)
}

func (agent *Agent) createAgentIdentity() {
Expand Down Expand Up @@ -209,8 +209,8 @@ func (agent *Agent) createAgentIdentity() {
}
}

func (agent *Agent) updateAgentIdentity(instanceId ulid.ULID) {
agent.logger.Debugf(context.Background(), "Agent identify is being changed from id=%v to id=%v",
func (agent *Agent) updateAgentIdentity(ctx context.Context, instanceId ulid.ULID) {
agent.logger.Debugf(ctx, "Agent identify is being changed from id=%v to id=%v",
agent.instanceId.String(),
instanceId.String())
agent.instanceId = instanceId
Expand Down Expand Up @@ -463,13 +463,13 @@ func (agent *Agent) onMessage(ctx context.Context, msg *types.MessageData) {
if err != nil {
agent.logger.Errorf(ctx, err.Error())
}
agent.updateAgentIdentity(newInstanceId)
agent.updateAgentIdentity(ctx, newInstanceId)
}

if configChanged {
err := agent.opampClient.UpdateEffectiveConfig(ctx)
if err != nil {
agent.logger.Errorf(context.Background(), err.Error())
agent.logger.Errorf(ctx, err.Error())
}
}

Expand All @@ -486,7 +486,7 @@ func (agent *Agent) onMessage(ctx context.Context, msg *types.MessageData) {
func (agent *Agent) tryChangeOpAMPCert(ctx context.Context, cert *tls.Certificate) {
agent.logger.Debugf(ctx, "Reconnecting to verify offered client certificate.\n")

agent.disconnect()
agent.disconnect(ctx)

agent.opampClientCert = cert
if err := agent.connect(); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/examples/server/opampsrv/opampsrv.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func (srv *Server) onDisconnect(conn types.Connection) {
srv.agents.RemoveConnection(conn)
}

func (srv *Server) onMessage(conn types.Connection, msg *protobufs.AgentToServer) *protobufs.ServerToAgent {
func (srv *Server) onMessage(ctx context.Context, conn types.Connection, msg *protobufs.AgentToServer) *protobufs.ServerToAgent {
instanceId := data.InstanceId(msg.InstanceUid)

agent := srv.agents.FindOrCreateAgent(instanceId, conn)
Expand Down
14 changes: 7 additions & 7 deletions internal/examples/supervisor/supervisor/supervisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ service:
s.agentConfigOwnMetricsSection.Store(cfg)

// Need to recalculate the Agent config so that the metric config is included in it.
configChanged, err := s.recalcEffectiveConfig()
configChanged, err := s.recalcEffectiveConfig(ctx)
if err != nil {
return
}
Expand All @@ -327,7 +327,7 @@ service:
// composeEffectiveConfig composes the effective config from multiple sources:
// 1) the remote config from OpAMP Server, 2) the own metrics config section,
// 3) the local override config that is hard-coded in the Supervisor.
func (s *Supervisor) composeEffectiveConfig(config *protobufs.AgentRemoteConfig) (configChanged bool, err error) {
func (s *Supervisor) composeEffectiveConfig(ctx context.Context, config *protobufs.AgentRemoteConfig) (configChanged bool, err error) {
var k = koanf.New(".")

// Begin with empty config. We will merge received configs on top of it.
Expand Down Expand Up @@ -387,7 +387,7 @@ func (s *Supervisor) composeEffectiveConfig(config *protobufs.AgentRemoteConfig)
newEffectiveConfig := string(effectiveConfigBytes)
configChanged = false
if s.effectiveConfig.Load().(string) != newEffectiveConfig {
s.logger.Debugf(context.Background(), "Effective config changed.")
s.logger.Debugf(ctx, "Effective config changed.")
s.effectiveConfig.Store(newEffectiveConfig)
configChanged = true
}
Expand All @@ -397,11 +397,11 @@ func (s *Supervisor) composeEffectiveConfig(config *protobufs.AgentRemoteConfig)

// Recalculate the Agent's effective config and if the config changes signal to the
// background goroutine that the config needs to be applied to the Agent.
func (s *Supervisor) recalcEffectiveConfig() (configChanged bool, err error) {
func (s *Supervisor) recalcEffectiveConfig(ctx context.Context) (configChanged bool, err error) {

configChanged, err = s.composeEffectiveConfig(s.remoteConfig)
configChanged, err = s.composeEffectiveConfig(ctx, s.remoteConfig)
if err != nil {
s.logger.Errorf(context.Background(), "Error composing effective config. Ignoring received config: %v", err)
s.logger.Errorf(ctx, "Error composing effective config. Ignoring received config: %v", err)
return configChanged, err
}

Expand Down Expand Up @@ -553,7 +553,7 @@ func (s *Supervisor) onMessage(ctx context.Context, msg *types.MessageData) {
s.logger.Debugf(ctx, "Received remote config from server, hash=%x.", s.remoteConfig.ConfigHash)

var err error
configChanged, err = s.recalcEffectiveConfig()
configChanged, err = s.recalcEffectiveConfig(ctx)
if err != nil {
s.opampClient.SetRemoteConfigStatus(&protobufs.RemoteConfigStatus{
LastRemoteConfigHash: msg.RemoteConfig.ConfigHash,
Expand Down
13 changes: 7 additions & 6 deletions server/callbacks.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"context"
"net/http"

"github.com/open-telemetry/opamp-go/protobufs"
Expand All @@ -27,25 +28,25 @@ func (c CallbacksStruct) OnConnecting(request *http.Request) types.ConnectionRes
// ConnectionCallbacksStruct is a struct that implements ConnectionCallbacks interface and allows
// to override only the methods that are needed.
type ConnectionCallbacksStruct struct {
OnConnectedFunc func(conn types.Connection)
OnMessageFunc func(conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent
OnConnectedFunc func(ctx context.Context, conn types.Connection)
OnMessageFunc func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent
OnConnectionCloseFunc func(conn types.Connection)
}

var _ types.ConnectionCallbacks = (*ConnectionCallbacksStruct)(nil)

// OnConnected implements ConnectionCallbacks.OnConnected.
func (c ConnectionCallbacksStruct) OnConnected(conn types.Connection) {
func (c ConnectionCallbacksStruct) OnConnected(ctx context.Context, conn types.Connection) {
if c.OnConnectedFunc != nil {
c.OnConnectedFunc(conn)
c.OnConnectedFunc(ctx, conn)
}
}

// OnMessage implements ConnectionCallbacks.OnMessage.
// If OnMessageFunc is nil then it will send an empty response to the agent
func (c ConnectionCallbacksStruct) OnMessage(conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
func (c ConnectionCallbacksStruct) OnMessage(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if c.OnMessageFunc != nil {
return c.OnMessageFunc(conn, message)
return c.OnMessageFunc(ctx, conn, message)
} else {
// We will send an empty response since there is no user-defined callback to handle it.
return &protobufs.ServerToAgent{
Expand Down
49 changes: 25 additions & 24 deletions server/serverimpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,16 +179,16 @@ func (s *server) httpHandler(w http.ResponseWriter, req *http.Request) {
// No, it is a WebSocket. Upgrade it.
conn, err := s.wsUpgrader.Upgrade(w, req, nil)
if err != nil {
s.logger.Errorf(context.Background(), "Cannot upgrade HTTP connection to WebSocket: %v", err)
s.logger.Errorf(req.Context(), "Cannot upgrade HTTP connection to WebSocket: %v", err)
return
}

// Return from this func to reduce memory usage.
// Handle the connection on a separate goroutine.
go s.handleWSConnection(conn, connectionCallbacks)
go s.handleWSConnection(req.Context(), conn, connectionCallbacks)
}

func (s *server) handleWSConnection(wsConn *websocket.Conn, connectionCallbacks serverTypes.ConnectionCallbacks) {
func (s *server) handleWSConnection(reqCtx context.Context, wsConn *websocket.Conn, connectionCallbacks serverTypes.ConnectionCallbacks) {
agentConn := wsConnection{wsConn: wsConn, connMutex: &sync.Mutex{}}

defer func() {
Expand All @@ -206,43 +206,44 @@ func (s *server) handleWSConnection(wsConn *websocket.Conn, connectionCallbacks
}()

if connectionCallbacks != nil {
connectionCallbacks.OnConnected(agentConn)
connectionCallbacks.OnConnected(reqCtx, agentConn)
}

// Loop until fail to read from the WebSocket connection.
for {
msgContext := context.Background()
// Block until the next message can be read.
mt, bytes, err := wsConn.ReadMessage()
mt, msgBytes, err := wsConn.ReadMessage()
if err != nil {
if !websocket.IsUnexpectedCloseError(err) {
s.logger.Errorf(context.Background(), "Cannot read a message from WebSocket: %v", err)
s.logger.Errorf(msgContext, "Cannot read a message from WebSocket: %v", err)
break
}
// This is a normal closing of the WebSocket connection.
s.logger.Debugf(context.Background(), "Agent disconnected: %v", err)
s.logger.Debugf(msgContext, "Agent disconnected: %v", err)
break
}
if mt != websocket.BinaryMessage {
s.logger.Errorf(context.Background(), "Received unexpected message type from WebSocket: %v", mt)
s.logger.Errorf(msgContext, "Received unexpected message type from WebSocket: %v", mt)
continue
}

// Decode WebSocket message as a Protobuf message.
var request protobufs.AgentToServer
err = internal.DecodeWSMessage(bytes, &request)
err = internal.DecodeWSMessage(msgBytes, &request)
if err != nil {
s.logger.Errorf(context.Background(), "Cannot decode message from WebSocket: %v", err)
s.logger.Errorf(msgContext, "Cannot decode message from WebSocket: %v", err)
continue
}

if connectionCallbacks != nil {
response := connectionCallbacks.OnMessage(agentConn, &request)
response := connectionCallbacks.OnMessage(msgContext, agentConn, &request)
if response.InstanceUid == "" {
response.InstanceUid = request.InstanceUid
}
err = agentConn.Send(context.Background(), response)
err = agentConn.Send(msgContext, response)
if err != nil {
s.logger.Errorf(context.Background(), "Cannot send message to WebSocket: %v", err)
s.logger.Errorf(msgContext, "Cannot send message to WebSocket: %v", err)
}
}
}
Expand Down Expand Up @@ -286,18 +287,18 @@ func compressGzip(data []byte) ([]byte, error) {
}

func (s *server) handlePlainHTTPRequest(req *http.Request, w http.ResponseWriter, connectionCallbacks serverTypes.ConnectionCallbacks) {
bytes, err := s.readReqBody(req)
bodyBytes, err := s.readReqBody(req)
if err != nil {
s.logger.Debugf(context.Background(), "Cannot read HTTP body: %v", err)
s.logger.Debugf(req.Context(), "Cannot read HTTP body: %v", err)
w.WriteHeader(http.StatusBadRequest)
return
}

// Decode the message as a Protobuf message.
var request protobufs.AgentToServer
err = proto.Unmarshal(bytes, &request)
err = proto.Unmarshal(bodyBytes, &request)
if err != nil {
s.logger.Debugf(context.Background(), "Cannot decode message from HTTP Body: %v", err)
s.logger.Debugf(req.Context(), "Cannot decode message from HTTP Body: %v", err)
w.WriteHeader(http.StatusBadRequest)
return
}
Expand All @@ -311,7 +312,7 @@ func (s *server) handlePlainHTTPRequest(req *http.Request, w http.ResponseWriter
return
}

connectionCallbacks.OnConnected(agentConn)
connectionCallbacks.OnConnected(req.Context(), agentConn)

defer func() {
// Indicate via the callback that the OpAMP Connection is closed. From OpAMP
Expand All @@ -321,15 +322,15 @@ func (s *server) handlePlainHTTPRequest(req *http.Request, w http.ResponseWriter
connectionCallbacks.OnConnectionClose(agentConn)
}()

response := connectionCallbacks.OnMessage(agentConn, &request)
response := connectionCallbacks.OnMessage(req.Context(), agentConn, &request)

// Set the InstanceUid if it is not set by the callback.
if response.InstanceUid == "" {
response.InstanceUid = request.InstanceUid
}

// Marshal the response.
bytes, err = proto.Marshal(response)
bodyBytes, err = proto.Marshal(response)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
Expand All @@ -338,17 +339,17 @@ func (s *server) handlePlainHTTPRequest(req *http.Request, w http.ResponseWriter
// Send the response.
w.Header().Set(headerContentType, contentTypeProtobuf)
if req.Header.Get(headerAcceptEncoding) == contentEncodingGzip {
bytes, err = compressGzip(bytes)
bodyBytes, err = compressGzip(bodyBytes)
if err != nil {
s.logger.Errorf(context.Background(), "Cannot compress response: %v", err)
s.logger.Errorf(req.Context(), "Cannot compress response: %v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Header().Set(headerContentEncoding, contentEncodingGzip)
}
_, err = w.Write(bytes)
_, err = w.Write(bodyBytes)

if err != nil {
s.logger.Debugf(context.Background(), "Cannot send HTTP response: %v", err)
s.logger.Debugf(req.Context(), "Cannot send HTTP response: %v", err)
}
}
Loading

0 comments on commit 4d07a6a

Please sign in to comment.