From c963c6f56e5b9643652136ccf8f4551f96ae49cd Mon Sep 17 00:00:00 2001 From: "Tahir H. Butt" Date: Tue, 3 Nov 2020 16:45:47 -0500 Subject: [PATCH] feat(dbapi): support Connection context management usage (#1762) * fix(psycopg2): add test for connection used with contextmanager * dbapi: handle Connection __enter__ * psycopg2 < 2.5 connection doesn't support context management * add dbapi to wordlist * add tests for mysqldb * add dbapi tests * Update ddtrace/contrib/dbapi/__init__.py Co-authored-by: Julien Danjou * add some doc strings * add test cases for other dbapi libraries * add test for no context manager * update comment Co-authored-by: Kyle Verhoog Co-authored-by: Julien Danjou Co-authored-by: Kyle Verhoog --- ddtrace/contrib/dbapi/__init__.py | 45 ++- docs/spelling_wordlist.txt | 100 ++++++ .../dbapi-ctx-manager-008915d487d9f50d.yaml | 4 + tests/contrib/dbapi/test_dbapi.py | 178 ++++++++++ tests/contrib/mysqldb/test_mysqldb.py | 32 +- tests/contrib/psycopg/test_psycopg.py | 7 + tests/contrib/pymysql/test_pymysql.py | 10 + tests/contrib/pyodbc/test_pyodbc.py | 312 ++++++++++++++++++ tests/contrib/sqlite3/test_sqlite3.py | 9 + 9 files changed, 695 insertions(+), 2 deletions(-) create mode 100644 docs/spelling_wordlist.txt create mode 100644 releasenotes/notes/dbapi-ctx-manager-008915d487d9f50d.yaml create mode 100644 tests/contrib/pyodbc/test_pyodbc.py diff --git a/ddtrace/contrib/dbapi/__init__.py b/ddtrace/contrib/dbapi/__init__.py index 9f8864af90a..6d80ccf2096 100644 --- a/ddtrace/contrib/dbapi/__init__.py +++ b/ddtrace/contrib/dbapi/__init__.py @@ -9,7 +9,7 @@ from ...settings import config from ...utils.formats import asbool, get_env from ...vendor import wrapt -from ..trace_utils import ext_service +from ..trace_utils import ext_service, iswrapped log = get_logger(__name__) @@ -178,6 +178,49 @@ def __init__(self, conn, pin=None, cfg=None, cursor_cls=None): self._self_cursor_cls = cursor_cls self._self_config = cfg or config.dbapi2 + def __enter__(self): + """Context management is not defined by the dbapi spec. + + This means unfortunately that the database clients each define their own + implementations. + + The ones we know about are: + + - mysqlclient<2.0 which returns a cursor instance. >=2.0 returns a + connection instance. + - psycopg returns a connection. + - pyodbc returns a connection. + - pymysql doesn't implement it. + - sqlite3 returns the connection. + """ + r = self.__wrapped__.__enter__() + + if hasattr(r, "cursor"): + # r is Connection-like. + if r is self.__wrapped__: + # Return the reference to this proxy object. Returning r would + # return the untraced reference. + return self + else: + # r is a different connection object. + # This should not happen in practice but play it safe so that + # the original functionality is maintained. + return r + elif hasattr(r, "execute"): + # r is Cursor-like. + if iswrapped(r): + return r + else: + pin = Pin.get_from(self) + cfg = _get_config(self._self_config) + if not pin: + return r + return self._self_cursor_cls(r, pin, cfg) + else: + # Otherwise r is some other object, so maintain the functionality + # of the original. + return r + def _trace_method(self, method, name, extra_tags, *args, **kwargs): pin = Pin.get_from(self) if not pin or not pin.enabled(): diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt new file mode 100644 index 00000000000..a43015dd4df --- /dev/null +++ b/docs/spelling_wordlist.txt @@ -0,0 +1,100 @@ +CPython +INfo +MySQL +OpenTracing +aiobotocore +aiohttp +aiopg +algolia +algoliasearch +analytics +api +app +asgi +autodetected +autopatching +backend +bikeshedding +boto +botocore +config +coroutine +coroutines +datadog +datadoghq +datastore +dbapi +ddtrace +django +dogstatsd +elasticsearch +enqueue +entrypoint +entrypoints +gRPC +gevent +greenlet +greenlets +grpc +hostname +http +httplib +https +iPython +integration +integrations +jinja +kombu +kubernetes +kwarg +lifecycle +mako +memcached +metadata +microservices +middleware +mongoengine +mysql +mysqlclient +mysqldb +namespace +opentracer +opentracing +plugin +posix +postgres +prepended +profiler +psycopg +py +pylibmc +pymemcache +pymongo +pymysql +pynamodb +pyodbc +quickstart +redis +rediscluster +renderers +repo +runnable +runtime +sanic +sqlalchemy +sqlite +starlette +stringable +subdomains +submodules +timestamp +tweens +uWSGI +unix +unregister +url +urls +username +uvicorn +vertica +whitelist diff --git a/releasenotes/notes/dbapi-ctx-manager-008915d487d9f50d.yaml b/releasenotes/notes/dbapi-ctx-manager-008915d487d9f50d.yaml new file mode 100644 index 00000000000..58c26348ed4 --- /dev/null +++ b/releasenotes/notes/dbapi-ctx-manager-008915d487d9f50d.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + dbapi: add support for connection context manager usage diff --git a/tests/contrib/dbapi/test_dbapi.py b/tests/contrib/dbapi/test_dbapi.py index 6ee5198c03a..a3a5ac61e71 100644 --- a/tests/contrib/dbapi/test_dbapi.py +++ b/tests/contrib/dbapi/test_dbapi.py @@ -1,5 +1,7 @@ import mock +import pytest + from ddtrace import Pin from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY from ddtrace.contrib.dbapi import FetchTracedCursor, TracedCursor, TracedConnection @@ -589,3 +591,179 @@ def test_connection_analytics_with_rate(self): traced_connection.commit() span = tracer.writer.pop()[0] self.assertIsNone(span.get_metric(ANALYTICS_SAMPLE_RATE_KEY)) + + def test_connection_context_manager(self): + + class Cursor(object): + rowcount = 0 + + def execute(self, *args, **kwargs): + pass + + def fetchall(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + def commit(self, *args, **kwargs): + pass + + # When a connection is returned from a context manager the object proxy + # should be returned so that tracing works. + + class ConnectionConnection(object): + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + def cursor(self): + return Cursor() + + def commit(self): + pass + + pin = Pin("pin", tracer=self.tracer) + conn = TracedConnection(ConnectionConnection(), pin) + with conn as conn2: + conn2.commit() + spans = self.tracer.writer.pop() + assert len(spans) == 1 + + with conn as conn2: + with conn2.cursor() as cursor: + cursor.execute("query") + cursor.fetchall() + + spans = self.tracer.writer.pop() + assert len(spans) == 1 + + # If a cursor is returned from the context manager + # then it should be instrumented. + + class ConnectionCursor(object): + def __enter__(self): + return Cursor() + + def __exit__(self, *exc): + return False + + def commit(self): + pass + + with TracedConnection(ConnectionCursor(), pin) as cursor: + cursor.execute("query") + cursor.fetchall() + spans = self.tracer.writer.pop() + assert len(spans) == 1 + + # If a traced cursor is returned then it should not + # be double instrumented. + + class ConnectionTracedCursor(object): + def __enter__(self): + return self.cursor() + + def __exit__(self, *exc): + return False + + def cursor(self): + return TracedCursor(Cursor(), pin, {}) + + def commit(self): + pass + + with TracedConnection(ConnectionTracedCursor(), pin) as cursor: + cursor.execute("query") + cursor.fetchall() + spans = self.tracer.writer.pop() + assert len(spans) == 1 + + # Check when a different connection object is returned + # from a connection context manager. + # No traces should be produced. + + other_conn = ConnectionConnection() + + class ConnectionDifferentConnection(object): + def __enter__(self): + return other_conn + + def __exit__(self, *exc): + return False + + def cursor(self): + return Cursor() + + def commit(self): + pass + + conn = TracedConnection(ConnectionDifferentConnection(), pin) + with conn as conn2: + conn2.commit() + spans = self.tracer.writer.pop() + assert len(spans) == 0 + + with conn as conn2: + with conn2.cursor() as cursor: + cursor.execute("query") + cursor.fetchall() + + spans = self.tracer.writer.pop() + assert len(spans) == 0 + + # When some unexpected value is returned from the context manager + # it should be handled gracefully. + + class ConnectionUnknown(object): + def __enter__(self): + return 123456 + + def __exit__(self, *exc): + return False + + def cursor(self): + return Cursor() + + def commit(self): + pass + + conn = TracedConnection(ConnectionDifferentConnection(), pin) + with conn as conn2: + conn2.commit() + spans = self.tracer.writer.pop() + assert len(spans) == 0 + + with conn as conn2: + with conn2.cursor() as cursor: + cursor.execute("query") + cursor.fetchall() + + spans = self.tracer.writer.pop() + assert len(spans) == 0 + + # Errors should be the same when no context management is defined. + + class ConnectionNoCtx(object): + def cursor(self): + return Cursor() + + def commit(self): + pass + + conn = TracedConnection(ConnectionNoCtx(), pin) + with pytest.raises(AttributeError): + with conn: + pass + + with pytest.raises(AttributeError): + with conn as conn2: + pass + + spans = self.tracer.writer.pop() + assert len(spans) == 0 diff --git a/tests/contrib/mysqldb/test_mysqldb.py b/tests/contrib/mysqldb/test_mysqldb.py index c24afaba645..0fe9202b5ce 100644 --- a/tests/contrib/mysqldb/test_mysqldb.py +++ b/tests/contrib/mysqldb/test_mysqldb.py @@ -1,3 +1,4 @@ +import pytest import MySQLdb from ddtrace import Pin @@ -25,7 +26,7 @@ def tearDown(self): if self.conn: try: self.conn.ping() - except MySQLdb.InterfaceError: + except (MySQLdb.InterfaceError, MySQLdb.OperationalError): pass else: self.conn.close() @@ -436,6 +437,35 @@ def test_user_specified_service(self): assert spans[0].service != "mysvc" + @pytest.mark.skipif((1, 4) < MySQLdb.version_info < (2, 0), reason="context manager interface not supported") + def test_contextmanager_connection(self): + conn, tracer = self._get_conn_tracer() + writer = tracer.writer + with conn as c: + if MySQLdb.version_info < (2, 0): + cursor = c + else: + cursor = c.cursor() + rowcount = cursor.execute('SELECT 1') + assert rowcount == 1 + rows = cursor.fetchall() + assert len(rows) == 1 + spans = writer.pop() + assert len(spans) == 1 + + span = spans[0] + assert_is_measured(span) + assert span.service == "mysql" + assert span.name == 'mysql.query' + assert span.span_type == 'sql' + assert span.error == 0 + assert span.get_metric('out.port') == 3306 + assert_dict_issuperset(span.meta, { + 'out.host': u'127.0.0.1', + 'db.name': u'test', + 'db.user': u'test', + }) + class TestMysqlPatch(MySQLCore, TracerTestCase): """Ensures MysqlDB is properly patched""" diff --git a/tests/contrib/psycopg/test_psycopg.py b/tests/contrib/psycopg/test_psycopg.py index 5d27542be9e..6b493f4d3c2 100644 --- a/tests/contrib/psycopg/test_psycopg.py +++ b/tests/contrib/psycopg/test_psycopg.py @@ -352,6 +352,13 @@ def test_user_specified_service(self): self.assertEqual(len(spans), 1) assert spans[0].service == "mysvc" + @skipIf(PSYCOPG2_VERSION < (2, 5), "Connection context managers not defined in <2.5.") + def test_contextmanager_connection(self): + service = "fo" + with self._get_conn(service=service) as conn: + conn.cursor().execute("""select 'blah'""") + self.assert_structure(dict(name='postgres.query', service=service)) + def test_backwards_compatibilty_v3(): tracer = DummyTracer() diff --git a/tests/contrib/pymysql/test_pymysql.py b/tests/contrib/pymysql/test_pymysql.py index 4f9a7cac0d7..08bc6c25846 100644 --- a/tests/contrib/pymysql/test_pymysql.py +++ b/tests/contrib/pymysql/test_pymysql.py @@ -443,6 +443,16 @@ def test_user_pin_override(self): span = spans[0] assert span.service == "pin-svc" + def test_context_manager(self): + conn, tracer = self._get_conn_tracer() + # connection doesn't support context manager usage + with conn.cursor() as cursor: + cursor.execute("SELECT 1") + rows = cursor.fetchall() + assert len(rows) == 1 + spans = tracer.writer.pop() + assert len(spans) == 1 + @TracerTestCase.run_in_subprocess(env_overrides=dict(DD_PYMYSQL_SERVICE="mysvc")) def test_user_specified_service_integration(self): conn, tracer = self._get_conn_tracer() diff --git a/tests/contrib/pyodbc/test_pyodbc.py b/tests/contrib/pyodbc/test_pyodbc.py new file mode 100644 index 00000000000..2f11ce4d6a2 --- /dev/null +++ b/tests/contrib/pyodbc/test_pyodbc.py @@ -0,0 +1,312 @@ +# 3p +import pyodbc + +# project +from ddtrace import Pin +from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY +from ddtrace.contrib.pyodbc.patch import patch, unpatch + +# testing +from ... import TracerTestCase, assert_is_measured + + +PYODBC_CONNECT_DSN = "driver=SQLite3;database=:memory:;" + + +class PyODBCTest(object): + """pyodbc test case reuses the connection across tests""" + + conn = None + + def setUp(self): + super(PyODBCTest, self).setUp() + patch() + + def tearDown(self): + super(PyODBCTest, self).tearDown() + if self.conn: + try: + self.conn.close() + except pyodbc.ProgrammingError: + pass + unpatch() + + def _get_conn_tracer(self): + pass + + def test_simple_query(self): + conn, tracer = self._get_conn_tracer() + writer = tracer.writer + cursor = conn.cursor() + cursor.execute("SELECT 1") + rows = cursor.fetchall() + assert len(rows) == 1 + spans = writer.pop() + assert len(spans) == 1 + + span = spans[0] + assert_is_measured(span) + assert span.service == "pyodbc" + assert span.name == "pyodbc.query" + assert span.span_type == "sql" + assert span.error == 0 + + def test_simple_query_fetchall(self): + with self.override_config("dbapi2", dict(trace_fetch_methods=True)): + conn, tracer = self._get_conn_tracer() + writer = tracer.writer + cursor = conn.cursor() + cursor.execute("SELECT 1") + rows = cursor.fetchall() + assert len(rows) == 1 + spans = writer.pop() + assert len(spans) == 1 + + span = spans[0] + assert_is_measured(span) + assert span.service == "pyodbc" + assert span.name == "pyodbc.query" + assert span.span_type == "sql" + assert span.error == 0 + fetch_span = spans[0] + assert fetch_span.name == "pyodbc.query" + + def test_query_with_several_rows(self): + conn, tracer = self._get_conn_tracer() + writer = tracer.writer + cursor = conn.cursor() + query = "SELECT n FROM (SELECT 42 n UNION SELECT 421 UNION SELECT 4210) m" + cursor.execute(query) + rows = cursor.fetchall() + assert len(rows) == 3 + spans = writer.pop() + assert len(spans) == 1 + self.assertEqual(spans[0].name, "pyodbc.query") + + def test_query_with_several_rows_fetchall(self): + with self.override_config("dbapi2", dict(trace_fetch_methods=True)): + conn, tracer = self._get_conn_tracer() + writer = tracer.writer + cursor = conn.cursor() + query = "SELECT n FROM (SELECT 42 n UNION SELECT 421 UNION SELECT 4210) m" + cursor.execute(query) + rows = cursor.fetchall() + assert len(rows) == 3 + spans = writer.pop() + assert len(spans) == 1 + + fetch_span = spans[0] + assert fetch_span.name == "pyodbc.query" + + def test_query_many(self): + # tests that the executemany method is correctly wrapped. + conn, tracer = self._get_conn_tracer() + writer = tracer.writer + tracer.enabled = False + cursor = conn.cursor() + + tracer.enabled = True + cursor.execute( + """ + create table if not exists dummy ( + dummy_key VARCHAR(32) PRIMARY KEY, + dummy_value TEXT NOT NULL)""" + ) + + stmt = "INSERT INTO dummy (dummy_key, dummy_value) VALUES (?, ?), (?, ?)" + data = ["foo", "this is foo", "bar", "this is bar"] + cursor.execute(stmt, data) + + query = "SELECT dummy_key, dummy_value FROM dummy ORDER BY dummy_key" + cursor.execute(query) + rows = cursor.fetchall() + assert len(rows) == 2 + assert rows[0][0] == "bar" + assert rows[0][1] == "this is bar" + assert rows[1][0] == "foo" + assert rows[1][1] == "this is foo" + + spans = writer.pop() + assert len(spans) == 3 + cursor.execute("drop table if exists dummy") + + def test_query_many_fetchall(self): + with self.override_config("dbapi2", dict(trace_fetch_methods=True)): + # tests that the executemany method is correctly wrapped. + conn, tracer = self._get_conn_tracer() + writer = tracer.writer + tracer.enabled = False + cursor = conn.cursor() + + tracer.enabled = True + cursor.execute( + """ + create table if not exists dummy ( + dummy_key VARCHAR(32) PRIMARY KEY, + dummy_value TEXT NOT NULL)""" + ) + + stmt = "INSERT INTO dummy (dummy_key, dummy_value) VALUES (?, ?)" + data = [("foo", "this is foo"), ("bar", "this is bar")] + cursor.executemany(stmt, data) + query = "SELECT dummy_key, dummy_value FROM dummy ORDER BY dummy_key" + cursor.execute(query) + rows = cursor.fetchall() + assert len(rows) == 2 + assert rows[0][0] == "bar" + assert rows[0][1] == "this is bar" + assert rows[1][0] == "foo" + assert rows[1][1] == "this is foo" + + spans = writer.pop() + assert len(spans) == 3 + cursor.execute("drop table if exists dummy") + + fetch_span = spans[2] + assert fetch_span.name == "pyodbc.query" + + def test_commit(self): + conn, tracer = self._get_conn_tracer() + writer = tracer.writer + conn.commit() + spans = writer.pop() + assert len(spans) == 1 + span = spans[0] + assert span.service == "pyodbc" + assert span.name == "pyodbc.connection.commit" + + def test_rollback(self): + conn, tracer = self._get_conn_tracer() + writer = tracer.writer + conn.rollback() + spans = writer.pop() + assert len(spans) == 1 + span = spans[0] + assert span.service == "pyodbc" + assert span.name == "pyodbc.connection.rollback" + + def test_analytics_default(self): + conn, tracer = self._get_conn_tracer() + writer = tracer.writer + cursor = conn.cursor() + cursor.execute("SELECT 1") + rows = cursor.fetchall() + assert len(rows) == 1 + spans = writer.pop() + + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertIsNone(span.get_metric(ANALYTICS_SAMPLE_RATE_KEY)) + + def test_analytics_with_rate(self): + with self.override_config("dbapi2", dict(analytics_enabled=True, analytics_sample_rate=0.5)): + conn, tracer = self._get_conn_tracer() + writer = tracer.writer + cursor = conn.cursor() + cursor.execute("SELECT 1") + rows = cursor.fetchall() + assert len(rows) == 1 + spans = writer.pop() + + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertEqual(span.get_metric(ANALYTICS_SAMPLE_RATE_KEY), 0.5) + + def test_analytics_without_rate(self): + with self.override_config("dbapi2", dict(analytics_enabled=True)): + conn, tracer = self._get_conn_tracer() + writer = tracer.writer + cursor = conn.cursor() + cursor.execute("SELECT 1") + rows = cursor.fetchall() + assert len(rows) == 1 + spans = writer.pop() + + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertEqual(span.get_metric(ANALYTICS_SAMPLE_RATE_KEY), 1.0) + + def test_context_manager(self): + conn, tracer = self._get_conn_tracer() + with conn as conn2: + with conn2.cursor() as cursor: + cursor.execute("SELECT 1") + rows = cursor.fetchall() + assert len(rows) == 1 + spans = tracer.writer.pop() + assert len(spans) == 1 + + +class TestPyODBCPatch(PyODBCTest, TracerTestCase): + def _get_conn_tracer(self): + if not self.conn: + self.conn = pyodbc.connect(PYODBC_CONNECT_DSN) + # Ensure that the default pin is there, with its default value + pin = Pin.get_from(self.conn) + assert pin + # Customize the service + # we have to apply it on the existing one since new one won't inherit `app` + pin.clone(tracer=self.tracer).onto(self.conn) + + return self.conn, self.tracer + + def test_patch_unpatch(self): + unpatch() + # assert we start unpatched + conn = pyodbc.connect(PYODBC_CONNECT_DSN) + assert not Pin.get_from(conn) + conn.close() + + patch() + try: + writer = self.tracer.writer + conn = pyodbc.connect(PYODBC_CONNECT_DSN) + pin = Pin.get_from(conn) + assert pin + pin.clone(tracer=self.tracer).onto(conn) + + cursor = conn.cursor() + cursor.execute("SELECT 1") + rows = cursor.fetchall() + assert len(rows) == 1 + spans = writer.pop() + assert len(spans) == 1 + + span = spans[0] + assert span.service == "pyodbc" + assert span.name == "pyodbc.query" + assert span.span_type == "sql" + assert span.error == 0 + finally: + unpatch() + + # assert we finish unpatched + conn = pyodbc.connect(PYODBC_CONNECT_DSN) + assert not Pin.get_from(conn) + conn.close() + + patch() + + def test_user_pin_override(self): + conn, tracer = self._get_conn_tracer() + pin = Pin.get_from(conn) + pin.clone(service="pin-svc", tracer=self.tracer).onto(conn) + cursor = conn.cursor() + cursor.execute("SELECT 1") + rows = cursor.fetchall() + assert len(rows) == 1 + spans = tracer.writer.pop() + assert len(spans) == 1 + + span = spans[0] + assert span.service == "pin-svc" + + @TracerTestCase.run_in_subprocess(env_overrides=dict(DD_PYODBC_SERVICE="my-pyodbc-service")) + def test_user_specified_service_integration(self): + conn, tracer = self._get_conn_tracer() + writer = tracer.writer + conn.rollback() + spans = writer.pop() + assert len(spans) == 1 + span = spans[0] + assert span.service == "my-pyodbc-service" diff --git a/tests/contrib/sqlite3/test_sqlite3.py b/tests/contrib/sqlite3/test_sqlite3.py index 6a051f2cb8e..f77e3e77b9f 100644 --- a/tests/contrib/sqlite3/test_sqlite3.py +++ b/tests/contrib/sqlite3/test_sqlite3.py @@ -364,3 +364,12 @@ def test_user_specified_service(self): self.assertEqual(len(spans), 1) span = spans[0] assert span.service == "my-svc" + + def test_context_manager(self): + conn = self._given_a_traced_connection(self.tracer) + with conn as conn2: + cursor = conn2.execute("select * from sqlite_master") + cursor.fetchall() + cursor.fetchall() + spans = self.get_spans() + assert len(spans) == 1