Skip to content

Commit

Permalink
feat(python): fastapi plugins for UT #540
Browse files Browse the repository at this point in the history
- rename lib name for nameing collision
- testcase for ut

close #540
  • Loading branch information
eeliu committed Oct 23, 2023
1 parent 02b8c7d commit fc55153
Show file tree
Hide file tree
Showing 16 changed files with 187 additions and 86 deletions.
12 changes: 8 additions & 4 deletions common/include/common.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ typedef enum {
E_UTEST = 0x4
} AGENT_FLAG;

typedef enum { E_INVALID_NODE = -1, E_ROOT_NODE = 0 } NodeID;

#pragma pack(1)
typedef struct {
uint32_t type;
Expand Down Expand Up @@ -87,9 +85,15 @@ typedef enum {
/**
* @brief at present only root checking
*/
typedef enum { E_LOC_CURRENT = 0x0, E_LOC_ROOT = 0x1 } E_NODE_LOC;

#define PINPOINT_C_AGENT_API_VERSION "@PROJECT_VERSION@"
typedef int NodeID;
typedef NodeID E_NODE_LOC;
static const NodeID E_INVALID_NODE = -1;
static const NodeID E_ROOT_NODE = 0;
static const E_NODE_LOC E_LOC_CURRENT = 0x0;
static const E_NODE_LOC E_LOC_ROOT = 0x1;

#define PINPOINT_C_AGENT_API_VERSION "0.4.23"

/**
* @brief change logs
Expand Down
39 changes: 19 additions & 20 deletions plugins/PY/pinpointPy/Fastapi/AsyCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
# ------------------------------------------------------------------------------



from ast import Assert
import asyncio

from starlette_context import context
Expand All @@ -29,67 +27,68 @@

class AsynPinTrace(object):

def __init__(self,name):
def __init__(self, name):
self.name = name

def getCurrentId(self):
id = context['_pinpoint_id_']
if not id:
raise 'not found traceId'
else:
return id
return id

def onBefore(self,parentId,*args, **kwargs):
def onBefore(self, parentId, *args, **kwargs):
traceId = pinpoint.with_trace(parentId)
# update global id
context['_pinpoint_id_'] = traceId
return traceId,args,kwargs
return traceId, args, kwargs

@staticmethod
def isSample(*args, **kwargs):
try:
parentid = context.get('_pinpoint_id_',0)
parentid = context.get('_pinpoint_id_', 0)
if parentid == 0:
return False,None
return True,parentid
return False, None
return True, parentid
except Exception as e:
return False,None
return False, None

@classmethod
def _isSample(cls,*args, **kwargs):
def _isSample(cls, *args, **kwargs):
return cls.isSample(*args, **kwargs)

def onEnd(self,parentId,ret):
def onEnd(self, parentId, ret):
parentId = pinpoint.end_trace(parentId)
context['_pinpoint_id_'] = parentId

def onException(self,traceId,e):
def onException(self, traceId, e):
raise NotImplementedError()

def __call__(self, func):
self.func_name=func.__name__
self.func_name = func.__name__

async def pinpointTrace(*args, **kwargs):
ret = None
sampled,parentId = self._isSample(args, kwargs)
sampled, parentId = self._isSample(args, kwargs)
if not sampled:
return await func(*args, **kwargs)

traceId,args,kwargs = self.onBefore(parentId,*args, **kwargs)
traceId, args, kwargs = self.onBefore(parentId, *args, **kwargs)
try:
ret = await func(*args, **kwargs)
return ret
except Exception as e:
self.onException(traceId,e)
self.onException(traceId, e)
raise e
finally:
self.onEnd(traceId,ret)
self.onEnd(traceId, ret)

return pinpointTrace

def getFuncUniqueName(self):
return self.name


if __name__ == '__main__':

@AsynPinTrace('main')
Expand All @@ -101,4 +100,4 @@ async def run(i):
await run(i-1)

asyncio.run(run(2))
asyncio.run(run(2))
asyncio.run(run(2))
33 changes: 21 additions & 12 deletions plugins/PY/pinpointPy/Fastapi/FastAPIRequestPlugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,33 @@
# ------------------------------------------------------------------------------


from pinpointPy.Fastapi.AsyRequestPlugin import *
from pinpointPy import Defines
from pinpointPy import pinpoint
from pinpointPy.Fastapi.AsyRequestPlugin import AsyRequestPlugin
from pinpointPy import Defines, pinpoint
import sys


class FastAPIRequestPlugin(AsyRequestPlugin):
def __init__(self, name):
super().__init__(name)

def onBefore(self,parentId,*args, **kwargs):
traceId,args,kwargs=super().onBefore(parentId,*args, **kwargs)
request = args[0].scope
pinpoint.add_trace_header(Defines.PP_INTERCEPTOR_NAME, 'fastapi-middleware',traceId)
def onBefore(self, parentId, *args, **kwargs):
traceId, args, kwargs = super().onBefore(parentId, *args, **kwargs)
request = args[0]
pinpoint.add_trace_header(
Defines.PP_INTERCEPTOR_NAME, 'fastapi-middleware', traceId)
pinpoint.add_trace_header(Defines.PP_REQ_URI, request["path"], traceId)
pinpoint.add_trace_header(Defines.PP_REQ_CLIENT, request["client"][0], traceId)
pinpoint.add_trace_header(Defines.PP_REQ_SERVER, request["server"][0] + ":" + str(request["server"][1]), traceId)
return traceId,args,kwargs
pinpoint.add_trace_header(
Defines.PP_REQ_CLIENT, f'{request.client.host}:{request.client.port}', traceId)
pinpoint.add_trace_header(
Defines.PP_REQ_SERVER, request.base_url.hostname, traceId)
self.request = request
return traceId, args, kwargs

def onEnd(self,traceId, ret):
return super().onEnd(traceId,ret)
def onEnd(self, traceId, response):
ut = self.request.scope['root_path'] + self.request.scope['route'].path
pinpoint.add_trace_header(Defines.PP_URL_TEMPLATED, ut, traceId)

if 'unittest' in sys.modules.keys():
response.headers["UT"] = ut

return super().onEnd(traceId, response)
6 changes: 3 additions & 3 deletions plugins/PY/pinpointPy/Fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import importlib
from pinpointPy.Fastapi.PinTranscation import PinTransaction, PinStarlettePlugin
from pinpointPy.Fastapi.AsyCommonPlugin import CommonPlugin
from pinpointPy.Fastapi.middleware import PinPointMiddleWare
from pinpointPy.Fastapi.middleware import PinPointMiddleWare, FastAPIRequestPlugin
from pinpointPy.Common import PinHeader, GenPinHeader


Expand All @@ -38,7 +38,7 @@ def asyn_monkey_patch_for_pinpoint(AioRedis=True, MotorMongo=True, httpx=True):
__monkey_patch(aioredis=AioRedis, MotorMongo=MotorMongo, httpx=httpx)


__version__ = '0.0.1'
__version__ = '0.0.2'
__author__ = 'liu.mingyi@navercorp.com'
__all__ = ['asyn_monkey_patch_for_pinpoint', 'PinPointMiddleWare',
__all__ = ['asyn_monkey_patch_for_pinpoint', 'FastAPIRequestPlugin', 'PinPointMiddleWare',
'CommonPlugin', 'PinTransaction', 'PinHeader', 'GenPinHeader', 'PinStarlettePlugin']
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,24 @@

# create by eelu

from pinpointPy.Interceptor import Interceptor,intercept_once
from pinpointPy.Interceptor import Interceptor, intercept_once
from pinpointPy import logger


@intercept_once
def monkey_patch():
try:
from httpx import AsyncClient
from _httpx import AsyncClient
from .httpxPlugins import HttpxRequestPlugins
Interceptors = [
Interceptor(AsyncClient, 'request',HttpxRequestPlugins)
Interceptor(AsyncClient, 'request', HttpxRequestPlugins)
]

for interceptor in Interceptors:
interceptor.enable()

except ImportError as e:
# do nothing
print(e)
logger.debug(f"import httpx:{e}")


__all__=['monkey_patch']
__all__ = ['monkey_patch']
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,11 @@
from os import stat
from pickle import FALSE
from random import sample
from .. import AsyCommon
from ... import pinpoint
from ... import Defines
from ... import Helper

from pinpointPy.Fastapi import AsyCommon
from pinpointPy import pinpoint, Defines, Helper
from urllib.parse import urlparse


class HttpxRequestPlugins(AsyCommon.AsynPinTrace):

def __init__(self, name):
Expand All @@ -40,46 +38,50 @@ def isSample(*args, **kwargs):
if not root, no trace
:return:
'''
sampled,parentId= AsyCommon.AsynPinTrace.isSample(*args,**kwargs)
sampled, parentId = AsyCommon.AsynPinTrace.isSample(*args, **kwargs)
if not sampled:
return False,None
return False, None

url = args[0][2]
target = urlparse(url).netloc
if "headers" not in kwargs or not kwargs['headers']:
kwargs["headers"] = {}

if pinpoint.get_context(Defines.PP_HEADER_PINPOINT_SAMPLED,parentId) == "s1":
Helper.generatePinpointHeader(target, kwargs['headers'],parentId)
return True,parentId
if pinpoint.get_context(Defines.PP_HEADER_PINPOINT_SAMPLED, parentId) == "s1":
Helper.generatePinpointHeader(target, kwargs['headers'], parentId)
return True, parentId
else:
kwargs['headers'][Defines.PP_HEADER_PINPOINT_SAMPLED] = Defines.PP_NOT_SAMPLED
return False ,None
return False, None

def onBefore(self,parentId, *args, **kwargs):
def onBefore(self, parentId, *args, **kwargs):
url = args[2]
target = urlparse(url).netloc
traceId,args,kwargs = super().onBefore(parentId,*args, **kwargs)
traceId, args, kwargs = super().onBefore(parentId, *args, **kwargs)
###############################################################
pinpoint.add_trace_header(Defines.PP_INTERCEPTOR_NAME, self.getFuncUniqueName(),traceId)
pinpoint.add_trace_header(Defines.PP_SERVER_TYPE, Defines.PP_REMOTE_METHOD,traceId)
pinpoint.add_trace_header_v2(Defines.PP_ARGS, url,traceId)
pinpoint.add_trace_header_v2(Defines.PP_HTTP_URL, url,traceId)
pinpoint.add_trace_header(Defines.PP_DESTINATION, target,traceId)
pinpoint.add_trace_header(
Defines.PP_INTERCEPTOR_NAME, self.getFuncUniqueName(), traceId)
pinpoint.add_trace_header(
Defines.PP_SERVER_TYPE, Defines.PP_REMOTE_METHOD, traceId)
pinpoint.add_trace_header_v2(Defines.PP_ARGS, url, traceId)
pinpoint.add_trace_header_v2(Defines.PP_HTTP_URL, url, traceId)
pinpoint.add_trace_header(Defines.PP_DESTINATION, target, traceId)
###############################################################
return traceId,args, kwargs
return traceId, args, kwargs

def onEnd(self,traceId, ret):
def onEnd(self, traceId, ret):
###############################################################
pinpoint.add_trace_header(Defines.PP_NEXT_SPAN_ID, pinpoint.get_context(Defines.PP_NEXT_SPAN_ID,traceId),traceId)
pinpoint.add_trace_header_v2(Defines.PP_HTTP_STATUS_CODE, str(ret.status_code),traceId)
pinpoint.add_trace_header_v2(Defines.PP_RETURN, str(ret),traceId)
pinpoint.add_trace_header(Defines.PP_NEXT_SPAN_ID, pinpoint.get_context(
Defines.PP_NEXT_SPAN_ID, traceId), traceId)
pinpoint.add_trace_header_v2(
Defines.PP_HTTP_STATUS_CODE, str(ret.status_code), traceId)
pinpoint.add_trace_header_v2(Defines.PP_RETURN, str(ret), traceId)
###############################################################
super().onEnd(traceId,ret)
super().onEnd(traceId, ret)
return ret

def onException(self,traceId, e):
pinpoint.add_trace_header(Defines.PP_ADD_EXCEPTION, str(e),traceId)
def onException(self, traceId, e):
pinpoint.add_trace_header(Defines.PP_ADD_EXCEPTION, str(e), traceId)

def get_arg(self, *args, **kwargs):
args_tmp = {}
Expand All @@ -92,4 +94,4 @@ def get_arg(self, *args, **kwargs):
for k in kwargs:
args_tmp[k] = kwargs[k]

return str(args_tmp)
return str(args_tmp)
9 changes: 4 additions & 5 deletions plugins/PY/pinpointPy/Fastapi/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,15 @@
# limitations under the License. -
# ------------------------------------------------------------------------------


from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from pinpointPy.Fastapi.FastAPIRequestPlugin import FastAPIRequestPlugin


class PinPointMiddleWare(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# todo create Root traceId
plugin = FastAPIRequestPlugin("")
traceId,_,_= plugin.onBefore(0,request)
traceId, _, _ = plugin.onBefore(0, request)
response = await call_next(request)
plugin.onEnd(traceId,response)
return response
plugin.onEnd(traceId, response)
return response
41 changes: 41 additions & 0 deletions plugins/PY/pinpointPy/Fastapi/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import unittest


from pinpointPy.Fastapi import PinPointMiddleWare
from pinpointPy import set_agent

from starlette_context.middleware import RawContextMiddleware
from starlette.middleware import Middleware

from fastapi import FastAPI
from fastapi.testclient import TestClient
from fastapi import Request


class Test_UT(unittest.TestCase):

def setUp(self) -> None:
middlewares = [
Middleware(
RawContextMiddleware
),
Middleware(PinPointMiddleWare)
]
app = FastAPI(title='pinpointpy test', middleware=middlewares)
set_agent("cd.dev.test.py", "cd.dev.test.py",
'tcp:dev-collector:9999', -1)

@app.get("/cluster/{name}")
async def read_main(name, request: Request):
return {"msg": f"Hello World,{name}"}

self.app = app
self.client = TestClient(app)

def test_request_example(self):
response = self.client.get("/cluster/abc")
assert "ut" in response.headers


if __name__ == '__main__':
unittest.main()
4 changes: 2 additions & 2 deletions plugins/PY/pinpointPy/Flask/test_flask.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
from flask import Flask, request
from flask import Flask
from pinpointPy.Flask.PinPointMiddleWare import PinPointMiddleWare
from pinpointPy import set_agent, monkey_patch_for_pinpoint
from pinpointPy import set_agent


class Test_Flask(unittest.TestCase):
Expand Down
Loading

0 comments on commit fc55153

Please sign in to comment.