-
Notifications
You must be signed in to change notification settings - Fork 110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SNOW-1418523 Make Session thread safe #2312
Changes from 17 commits
fb4ecf6
f720701
eca13dc
96949be
5f140ab
0624824
42d6e19
801ad6e
c7fa3ae
5672a1d
bd0528d
39a07d4
4d4e257
66374ee
8e57d95
822e3f8
860d947
805ca91
1583ad9
9aeb1f1
2e78e4a
2d2ff8a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
name: Threadsafe Check | ||
|
||
on: | ||
pull_request: | ||
types: [opened, synchronize, labeled, unlabeled, edited] | ||
|
||
jobs: | ||
check_threadsafety: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- name: Checkout | ||
uses: actions/checkout@v4 | ||
with: | ||
fetch-depth: 0 | ||
|
||
- name: Check for modified files | ||
id: changed_files | ||
run: | | ||
echo "changed_files=$(git diff --name-only --diff-filter=ACMRT ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | xargs)" >> $GITHUB_OUTPUT | ||
|
||
- name: Verify threadsafety acknowledgement | ||
run: | | ||
CHANGED_FILES="${{ steps.changed_files.outputs.changed_files }}" | ||
# Check if changed files are in snowpark/_internal, snowpark/mock, or snowpark/*.py. We exclude snowpark/modin in this check. | ||
if echo "$CHANGED_FILES" | grep -qE '(src/snowflake/snowpark/_internal|src/snowflake/snowpark/mock|src/snowflake/snowpark/[^/]+\.py)'; then | ||
echo "Checking PR description for thread-safety acknowledgment..." | ||
if [[ "${{ github.event.pull_request.body }}" != *"[x] I acknowledge that I have ensured my changes to be thread-safe"* ]]; then | ||
echo "Thread-safety acknowledgment not found in PR description." | ||
echo "Please acknowledge the threadsafety implications of your changes by adding '[x] I acknowledge that I have ensured my changes to be thread-safe' to the PR description." | ||
exit 1 | ||
else | ||
echo "Thread-safety acknowledgment found in PR description." | ||
fi | ||
else | ||
echo "No critical files modified; skipping threadsafety check." | ||
fi |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,8 @@ | |
|
||
#### New Features | ||
|
||
- Updated `Session` class to be thread-safe. This allows concurrent query submission, dataframe operations, UDF and store procedure registration, and concurret file uploads. | ||
- One limitation is that updating `Session` configurations inside multiple-threads may cause other active thread to break. Please update `Session` configurations before starting multiple threads. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This details isn't needed in the changelog. It should be somewhere else in the doc. |
||
- Added the following new functions in `snowflake.snowpark.functions`: | ||
- `make_interval` | ||
- Added support for using Snowflake Interval constants with `Window.range_between()` when the order by column is TIMESTAMP or DATE type. | ||
|
@@ -180,7 +182,7 @@ This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for det | |
|
||
- Improve concat, join performance when operations are performed on series coming from the same dataframe by avoiding unnecessary joins. | ||
- Refactored `quoted_identifier_to_snowflake_type` to avoid making metadata queries if the types have been cached locally. | ||
- Improved `pd.to_datetime` to handle all local input cases. | ||
- Improved `pd.to_datetime` to handle all local input cases. | ||
- Create a lazy index from another lazy index without pulling data to client. | ||
- Raised `NotImplementedError` for Index bitwise operators. | ||
- Display a more clear error message when `Index.names` is set to a non-like-like object. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -155,6 +155,8 @@ def __init__( | |
options: Dict[str, Union[int, str]], | ||
conn: Optional[SnowflakeConnection] = None, | ||
) -> None: | ||
self._lock = threading.RLock() | ||
self._thread_store = threading.local() | ||
self._lower_case_parameters = {k.lower(): v for k, v in options.items()} | ||
self._add_application_parameters() | ||
self._conn = conn if conn else connect(**self._lower_case_parameters) | ||
|
@@ -171,7 +173,6 @@ def __init__( | |
|
||
if "password" in self._lower_case_parameters: | ||
self._lower_case_parameters["password"] = None | ||
self._cursor = self._conn.cursor() | ||
self._telemetry_client = TelemetryClient(self._conn) | ||
self._query_listener: Set[QueryHistory] = set() | ||
# The session in this case refers to a Snowflake session, not a | ||
|
@@ -184,6 +185,15 @@ def __init__( | |
"_skip_upload_on_content_match" in signature.parameters | ||
) | ||
|
||
@property | ||
def _cursor(self) -> SnowflakeCursor: | ||
if not hasattr(self._thread_store, "cursor"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, |
||
self._thread_store.cursor = self._conn.cursor() | ||
self._telemetry_client.send_cursor_created_telemetry( | ||
self.get_session_id(), threading.get_ident() | ||
) | ||
return self._thread_store.cursor | ||
|
||
def _add_application_parameters(self) -> None: | ||
if PARAM_APPLICATION not in self._lower_case_parameters: | ||
# Mirrored from snowflake-connector-python/src/snowflake/connector/connection.py#L295 | ||
|
@@ -211,10 +221,12 @@ def _add_application_parameters(self) -> None: | |
] = get_version() | ||
|
||
def add_query_listener(self, listener: QueryHistory) -> None: | ||
self._query_listener.add(listener) | ||
with self._lock: | ||
self._query_listener.add(listener) | ||
|
||
def remove_query_listener(self, listener: QueryHistory) -> None: | ||
self._query_listener.remove(listener) | ||
with self._lock: | ||
self._query_listener.remove(listener) | ||
|
||
def close(self) -> None: | ||
if self._conn: | ||
|
@@ -253,17 +265,21 @@ def _run_new_describe( | |
) -> Union[List[ResultMetadata], List["ResultMetadataV2"]]: | ||
result_metadata = run_new_describe(cursor, query) | ||
|
||
for listener in filter( | ||
lambda listener: hasattr(listener, "include_describe") | ||
and listener.include_describe, | ||
self._query_listener, | ||
): | ||
query_record = QueryRecord(cursor.sfqid, query, True) | ||
if getattr(listener, "include_thread_id", False): | ||
with self._lock: | ||
for listener in filter( | ||
lambda listener: hasattr(listener, "include_describe") | ||
and listener.include_describe, | ||
self._query_listener, | ||
): | ||
thread_id = ( | ||
threading.get_ident() | ||
if getattr(listener, "include_thread_id", False) | ||
else None | ||
) | ||
query_record = QueryRecord( | ||
cursor.sfqid, query, True, threading.get_ident() | ||
cursor.sfqid, query, True, thread_id=thread_id | ||
) | ||
listener._add_query(query_record) | ||
listener._add_query(query_record) | ||
|
||
return result_metadata | ||
|
||
|
@@ -380,17 +396,18 @@ def upload_stream( | |
raise ex | ||
|
||
def notify_query_listeners(self, query_record: QueryRecord) -> None: | ||
for listener in self._query_listener: | ||
if getattr(listener, "include_thread_id", False): | ||
new_record = QueryRecord( | ||
query_record.query_id, | ||
query_record.sql_text, | ||
query_record.is_describe, | ||
thread_id=threading.get_ident(), | ||
) | ||
listener._add_query(new_record) | ||
else: | ||
listener._add_query(query_record) | ||
with self._lock: | ||
for listener in self._query_listener: | ||
if getattr(listener, "include_thread_id", False): | ||
new_record = QueryRecord( | ||
query_record.query_id, | ||
query_record.sql_text, | ||
query_record.is_describe, | ||
thread_id=threading.get_ident(), | ||
) | ||
listener._add_query(new_record) | ||
else: | ||
listener._add_query(query_record) | ||
|
||
def execute_and_notify_query_listener( | ||
self, query: str, **kwargs: Any | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
# | ||
|
||
import functools | ||
import threading | ||
from enum import Enum, unique | ||
from typing import Any, Dict, List, Optional | ||
|
||
|
@@ -38,6 +39,7 @@ class TelemetryField(Enum): | |
TYPE_PERFORMANCE_DATA = "snowpark_performance_data" | ||
TYPE_FUNCTION_USAGE = "snowpark_function_usage" | ||
TYPE_SESSION_CREATED = "snowpark_session_created" | ||
TYPE_CURSOR_CREATED = "snowpark_cursor_created" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the purpose to know whether multiple cursor is used? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, this addition is part of collecting telemtry. It will track if a cursor is used in a new thread. |
||
TYPE_SQL_SIMPLIFIER_ENABLED = "snowpark_sql_simplifier_enabled" | ||
TYPE_CTE_OPTIMIZATION_ENABLED = "snowpark_cte_optimization_enabled" | ||
# telemetry for optimization that eliminates the extra cast expression generated for expressions | ||
|
@@ -90,6 +92,8 @@ class TelemetryField(Enum): | |
TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_MESSAGE = ( | ||
"temp_table_cleanup_abnormal_exception_message" | ||
) | ||
# multi-threading | ||
THREAD_IDENTIFIER = "thread_ident" | ||
|
||
|
||
# These DataFrame APIs call other DataFrame APIs | ||
|
@@ -196,6 +200,7 @@ def wrap(*args, **kwargs): | |
key.value: value | ||
for key, value in plan.cumulative_node_complexity.items() | ||
} | ||
api_calls[0][TelemetryField.THREAD_IDENTIFIER.value] = threading.get_ident() | ||
except Exception: | ||
pass | ||
args[0]._session._conn._telemetry_client.send_function_usage_telemetry( | ||
|
@@ -343,6 +348,7 @@ def send_upload_file_perf_telemetry( | |
TelemetryField.KEY_CATEGORY.value: TelemetryField.PERF_CAT_UPLOAD_FILE.value, | ||
TelemetryField.KEY_FUNC_NAME.value: func_name, | ||
TelemetryField.KEY_DURATION.value: duration, | ||
TelemetryField.THREAD_IDENTIFIER.value: threading.get_ident(), | ||
}, | ||
} | ||
self.send(message) | ||
|
@@ -537,3 +543,15 @@ def send_large_query_breakdown_update_complexity_bounds( | |
}, | ||
} | ||
self.send(message) | ||
|
||
def send_cursor_created_telemetry(self, session_id: int, thread_id: int): | ||
message = { | ||
**self._create_basic_telemetry_data( | ||
TelemetryField.TYPE_CURSOR_CREATED.value | ||
), | ||
TelemetryField.KEY_DATA.value: { | ||
TelemetryField.SESSION_ID.value: session_id, | ||
TelemetryField.THREAD_IDENTIFIER.value: thread_id, | ||
}, | ||
} | ||
self.send(message) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo: concurret