
SNOW-1418523 Make Session thread safe #2312

Merged
Commits (22)
fb4ecf6
add locks
sfc-gh-aalam Sep 11, 2024
f720701
Merge branch 'main' into aalam-SNOW-1418523-make-internal-session-var…
sfc-gh-aalam Sep 12, 2024
eca13dc
Merge branch 'main' into aalam-SNOW-1418523-make-internal-session-var…
sfc-gh-aalam Sep 17, 2024
96949be
Merge branch 'main' into aalam-SNOW-1418523-make-internal-session-var…
sfc-gh-aalam Sep 18, 2024
5f140ab
SNOW-1418523 make analyzer server connection thread safe (#2282)
sfc-gh-aalam Sep 25, 2024
0624824
SNOW-1418523: concurrent file operations (#2288)
sfc-gh-aalam Sep 25, 2024
42d6e19
SNOW-1418523: make udf and sproc registration thread safe (#2289)
sfc-gh-aalam Sep 25, 2024
801ad6e
merge with main
sfc-gh-aalam Oct 2, 2024
c7fa3ae
Merge branch 'main' into aalam-SNOW-1418523-make-internal-session-var…
sfc-gh-aalam Oct 3, 2024
5672a1d
SNOW-1663726 make session config updates thread safe (#2302)
sfc-gh-aalam Oct 4, 2024
bd0528d
SNOW-1663726 make temp table cleaner thread safe (#2309)
sfc-gh-aalam Oct 4, 2024
39a07d4
SNOW-1642189: collect telemetry about concurrency usage (#2316)
sfc-gh-aalam Oct 4, 2024
4d4e257
SNOW-1546090 add merge gate for future thread safe updates (#2323)
sfc-gh-aalam Oct 4, 2024
66374ee
add plan-builder that was accidentally removed
sfc-gh-aalam Oct 4, 2024
8e57d95
changelog updates
sfc-gh-aalam Oct 4, 2024
822e3f8
create hyperlink for doc
sfc-gh-aalam Oct 4, 2024
860d947
Merge branch 'main' into aalam-SNOW-1418523-make-internal-session-var…
sfc-gh-aalam Oct 8, 2024
805ca91
address comments
sfc-gh-aalam Oct 14, 2024
1583ad9
Merge branch 'main' into aalam-SNOW-1418523-make-internal-session-var…
sfc-gh-aalam Oct 14, 2024
9aeb1f1
fix changelog
sfc-gh-aalam Oct 14, 2024
2e78e4a
fix changelog
sfc-gh-aalam Oct 14, 2024
2d2ff8a
address feedback
sfc-gh-aalam Oct 14, 2024
3 changes: 2 additions & 1 deletion .github/pull_request_template.md
@@ -6,7 +6,7 @@ Please answer these questions before creating your pull request. Thanks!

<!---
In this section, please add a Snowflake Jira issue number.

Note that if a corresponding GitHub issue exists, you should still include
the Snowflake Jira issue number. For example, for GitHub issue
https://github.com/snowflakedb/snowpark-python/issues/1400, you should
@@ -24,6 +24,7 @@ Please answer these questions before creating your pull request. Thanks!
- [ ] I am adding new credentials
- [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes.
- [ ] I acknowledge that I have ensured my changes to be thread-safe. Follow the link for more information: [Thread-safe Developer Guidelines](https://docs.google.com/document/d/162d_i4zZ2AfcGRXojj0jByt8EUq-DrSHPPnTa4QvwbA/edit#bookmark=id.e82u4nekq80k)

3. Please describe how your code solves the related issue.

36 changes: 36 additions & 0 deletions .github/workflows/threadsafe-check.yml
@@ -0,0 +1,36 @@
name: Threadsafe Check

on:
pull_request:
types: [opened, synchronize, labeled, unlabeled, edited]

jobs:
check_threadsafety:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0

- name: Check for modified files
id: changed_files
run: |
echo "changed_files=$(git diff --name-only --diff-filter=ACMRT ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | xargs)" >> $GITHUB_OUTPUT

- name: Verify threadsafety acknowledgement
run: |
CHANGED_FILES="${{ steps.changed_files.outputs.changed_files }}"
# Check if changed files are in snowpark/_internal, snowpark/mock, or snowpark/*.py. We exclude snowpark/modin in this check.
if echo "$CHANGED_FILES" | grep -qE '(src/snowflake/snowpark/_internal|src/snowflake/snowpark/mock|src/snowflake/snowpark/[^/]+\.py)'; then
echo "Checking PR description for thread-safety acknowledgment..."
if [[ "${{ github.event.pull_request.body }}" != *"[x] I acknowledge that I have ensured my changes to be thread-safe"* ]]; then
echo "Thread-safety acknowledgment not found in PR description."
echo "Please acknowledge the threadsafety implications of your changes by adding '[x] I acknowledge that I have ensured my changes to be thread-safe' to the PR description."
exit 1
else
echo "Thread-safety acknowledgment found in PR description."
fi
else
echo "No critical files modified; skipping threadsafety check."
fi
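
The path filter above can be checked locally before pushing. A minimal sketch (not part of the PR) that reproduces the workflow's `grep -E` pattern with Python's `re` module:

```python
# Reproduces the threadsafe-check path filter; same pattern the
# workflow passes to grep -E.
import re

CRITICAL = re.compile(
    r"(src/snowflake/snowpark/_internal"
    r"|src/snowflake/snowpark/mock"
    r"|src/snowflake/snowpark/[^/]+\.py)"
)

paths = [
    "src/snowflake/snowpark/session.py",              # matches: top-level module
    "src/snowflake/snowpark/_internal/telemetry.py",  # matches: _internal
    "src/snowflake/snowpark/modin/plugin/frame.py",   # no match: modin is excluded
]
for path in paths:
    print(path, "-> gated" if CRITICAL.search(path) else "-> skipped")
```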
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -6,6 +6,8 @@

#### New Features

- Updated `Session` class to be thread-safe. This allows concurrent query submission, dataframe operations, UDF and store procedure registration, and concurret file uploads.
[Review comment — Collaborator] Typo: concurret

- One limitation is that updating `Session` configurations from multiple threads may cause other active threads to break. Please update `Session` configurations before starting multiple threads.
[Review comment — Collaborator] This detail isn't needed in the changelog; it should go somewhere else in the docs.

- Added the following new functions in `snowflake.snowpark.functions`:
- `make_interval`
- Added support for using Snowflake Interval constants with `Window.range_between()` when the order by column is TIMESTAMP or DATE type.
@@ -180,7 +182,7 @@ This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for details.

- Improve concat, join performance when operations are performed on series coming from the same dataframe by avoiding unnecessary joins.
- Refactored `quoted_identifier_to_snowflake_type` to avoid making metadata queries if the types have been cached locally.
- Improved `pd.to_datetime` to handle all local input cases.
- Create a lazy index from another lazy index without pulling data to client.
- Raised `NotImplementedError` for Index bitwise operators.
- Display a more clear error message when `Index.names` is set to a non-list-like object.
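
Taken together, the two changelog entries above describe the headline behavior and its one caveat. A hedged usage sketch of what the thread-safe `Session` enables — connection parameters and table names are placeholders, and all configuration happens before worker threads start:

```python
from concurrent.futures import ThreadPoolExecutor

from snowflake.snowpark import Session

# Placeholder credentials; substitute real connection parameters.
session = Session.builder.configs({"account": "...", "user": "..."}).create()

# Per the limitation noted above: finish Session configuration updates
# before any worker threads are running.
session.cte_optimization_enabled = True

def row_count(table_name: str) -> int:
    # Independent queries may be submitted concurrently from many threads.
    return session.table(table_name).count()

with ThreadPoolExecutor(max_workers=4) as pool:
    print(list(pool.map(row_count, ["T1", "T2", "T3"])))
```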
src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py
@@ -113,17 +113,14 @@ def __init__(
session: Session,
query_generator: QueryGenerator,
logical_plans: List[LogicalPlan],
complexity_bounds: Tuple[int, int],
) -> None:
self.session = session
self._query_generator = query_generator
self.logical_plans = logical_plans
self._parent_map = defaultdict(set)
self.complexity_score_lower_bound = (
session.large_query_breakdown_complexity_bounds[0]
)
self.complexity_score_upper_bound = (
session.large_query_breakdown_complexity_bounds[1]
)
self.complexity_score_lower_bound = complexity_bounds[0]
self.complexity_score_upper_bound = complexity_bounds[1]

def apply(self) -> List[LogicalPlan]:
if is_active_transaction(self.session):
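The constructor change above replaces two separate mid-run reads of `session.large_query_breakdown_complexity_bounds` with a single value captured at construction, so a concurrent configuration update cannot leave the lower and upper bounds inconsistent. A minimal sketch of that snapshot pattern — `Config` and `Rewriter` are illustrative names, not the PR's classes:

```python
# Snapshot mutable, lock-guarded state once instead of re-reading it mid-work.
import threading
from typing import Tuple

class Config:
    def __init__(self) -> None:
        self._lock = threading.RLock()
        self._bounds = (300, 600)

    @property
    def bounds(self) -> Tuple[int, int]:
        with self._lock:
            return self._bounds  # both values read under one lock acquisition

class Rewriter:
    def __init__(self, bounds: Tuple[int, int]) -> None:
        # Later config updates cannot split this pair mid-compilation.
        self.lower_bound, self.upper_bound = bounds

rewriter = Rewriter(Config().bounds)
print(rewriter.lower_bound, rewriter.upper_bound)
```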
14 changes: 8 additions & 6 deletions src/snowflake/snowpark/_internal/compiler/plan_compiler.py
@@ -77,6 +77,7 @@ def should_start_query_compilation(self) -> bool:

def compile(self) -> Dict[PlanQueryType, List[Query]]:
if self.should_start_query_compilation():
session = self._plan.session
# preparation for compilation
# 1. make a copy of the original plan
start_time = time.time()
@@ -92,7 +93,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
# 3. apply each optimizations if needed
# CTE optimization
cte_start_time = time.time()
if self._plan.session.cte_optimization_enabled:
if session.cte_optimization_enabled:
repeated_subquery_eliminator = RepeatedSubqueryElimination(
logical_plans, query_generator
)
@@ -111,9 +112,12 @@
plot_plan_if_enabled(plan, f"cte_optimized_plan_{i}")

# Large query breakdown
if self._plan.session.large_query_breakdown_enabled:
if session.large_query_breakdown_enabled:
large_query_breakdown = LargeQueryBreakdown(
self._plan.session, query_generator, logical_plans
session,
query_generator,
logical_plans,
session.large_query_breakdown_complexity_bounds,
)
logical_plans = large_query_breakdown.apply()

@@ -132,7 +136,6 @@
cte_time = cte_end_time - cte_start_time
large_query_breakdown_time = large_query_breakdown_end_time - cte_end_time
total_time = time.time() - start_time
session = self._plan.session
summary_value = {
TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled,
TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled,
@@ -155,8 +158,7 @@
return queries
else:
final_plan = self._plan
if self._plan.session.cte_optimization_enabled:
final_plan = final_plan.replace_repeated_subquery_with_cte()
final_plan = final_plan.replace_repeated_subquery_with_cte()
sfc-gh-jrose marked this conversation as resolved.
return {
PlanQueryType.QUERIES: final_plan.queries,
PlanQueryType.POST_ACTIONS: final_plan.post_actions,
63 changes: 40 additions & 23 deletions src/snowflake/snowpark/_internal/server_connection.py
@@ -155,6 +155,8 @@ def __init__(
options: Dict[str, Union[int, str]],
conn: Optional[SnowflakeConnection] = None,
) -> None:
self._lock = threading.RLock()
self._thread_store = threading.local()
self._lower_case_parameters = {k.lower(): v for k, v in options.items()}
self._add_application_parameters()
self._conn = conn if conn else connect(**self._lower_case_parameters)
@@ -171,7 +173,6 @@

if "password" in self._lower_case_parameters:
self._lower_case_parameters["password"] = None
self._cursor = self._conn.cursor()
self._telemetry_client = TelemetryClient(self._conn)
self._query_listener: Set[QueryHistory] = set()
# The session in this case refers to a Snowflake session, not a
@@ -184,6 +185,15 @@
"_skip_upload_on_content_match" in signature.parameters
)

@property
def _cursor(self) -> SnowflakeCursor:
if not hasattr(self._thread_store, "cursor"):
[Review comment — Collaborator] Is hasattr() itself thread-safe?
[Reply — Author] Yes, hasattr is a read-only operation on the thread-local object.

self._thread_store.cursor = self._conn.cursor()
self._telemetry_client.send_cursor_created_telemetry(
self.get_session_id(), threading.get_ident()
)
return self._thread_store.cursor
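
The property above gives each thread a lazily created cursor of its own via `threading.local()`. A self-contained sketch of the same pattern, with a stub standing in for `SnowflakeConnection`:

```python
import threading

class StubConnection:
    def cursor(self) -> object:
        return object()  # stands in for a real SnowflakeCursor

class ThreadLocalCursors:
    def __init__(self) -> None:
        self._conn = StubConnection()
        self._thread_store = threading.local()

    @property
    def cursor(self) -> object:
        # hasattr on a threading.local is a per-thread read, so each thread
        # creates its cursor once and reuses it on later accesses.
        if not hasattr(self._thread_store, "cursor"):
            self._thread_store.cursor = self._conn.cursor()
        return self._thread_store.cursor

holder = ThreadLocalCursors()
threads = [threading.Thread(target=lambda: print(id(holder.cursor))) for _ in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```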

def _add_application_parameters(self) -> None:
if PARAM_APPLICATION not in self._lower_case_parameters:
# Mirrored from snowflake-connector-python/src/snowflake/connector/connection.py#L295
@@ -211,10 +221,12 @@ def _add_application_parameters(self) -> None:
] = get_version()

def add_query_listener(self, listener: QueryHistory) -> None:
self._query_listener.add(listener)
with self._lock:
self._query_listener.add(listener)

def remove_query_listener(self, listener: QueryHistory) -> None:
self._query_listener.remove(listener)
with self._lock:
self._query_listener.remove(listener)

def close(self) -> None:
if self._conn:
@@ -253,17 +265,21 @@ def _run_new_describe(
) -> Union[List[ResultMetadata], List["ResultMetadataV2"]]:
result_metadata = run_new_describe(cursor, query)

for listener in filter(
lambda listener: hasattr(listener, "include_describe")
and listener.include_describe,
self._query_listener,
):
query_record = QueryRecord(cursor.sfqid, query, True)
if getattr(listener, "include_thread_id", False):
with self._lock:
for listener in filter(
lambda listener: hasattr(listener, "include_describe")
and listener.include_describe,
self._query_listener,
):
thread_id = (
threading.get_ident()
if getattr(listener, "include_thread_id", False)
else None
)
query_record = QueryRecord(
cursor.sfqid, query, True, threading.get_ident()
cursor.sfqid, query, True, thread_id=thread_id
)
listener._add_query(query_record)
listener._add_query(query_record)

return result_metadata

@@ -380,17 +396,18 @@ def upload_stream(
raise ex

def notify_query_listeners(self, query_record: QueryRecord) -> None:
for listener in self._query_listener:
if getattr(listener, "include_thread_id", False):
new_record = QueryRecord(
query_record.query_id,
query_record.sql_text,
query_record.is_describe,
thread_id=threading.get_ident(),
)
listener._add_query(new_record)
else:
listener._add_query(query_record)
with self._lock:
for listener in self._query_listener:
if getattr(listener, "include_thread_id", False):
new_record = QueryRecord(
query_record.query_id,
query_record.sql_text,
query_record.is_describe,
thread_id=threading.get_ident(),
)
listener._add_query(new_record)
else:
listener._add_query(query_record)

def execute_and_notify_query_listener(
self, query: str, **kwargs: Any
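With listener registration and notification now lock-guarded, a query-history listener that opts into thread ids receives one `QueryRecord` per query, tagged with the submitting thread. A hedged usage sketch — the `include_thread_id` argument to `query_history()` is inferred from the `getattr` checks in this diff, not defined by it, and `session` is the one from the earlier sketch:

```python
from concurrent.futures import ThreadPoolExecutor

# include_thread_id is assumed to be exposed by query_history().
with session.query_history(include_thread_id=True) as history:
    with ThreadPoolExecutor(max_workers=2) as pool:
        list(pool.map(lambda q: session.sql(q).collect(), ["select 1", "select 2"]))

for record in history.queries:
    # thread_id is populated by notify_query_listeners() above.
    print(record.query_id, record.thread_id)
```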
18 changes: 18 additions & 0 deletions src/snowflake/snowpark/_internal/telemetry.py
@@ -4,6 +4,7 @@
#

import functools
import threading
from enum import Enum, unique
from typing import Any, Dict, List, Optional

@@ -38,6 +39,7 @@ class TelemetryField(Enum):
TYPE_PERFORMANCE_DATA = "snowpark_performance_data"
TYPE_FUNCTION_USAGE = "snowpark_function_usage"
TYPE_SESSION_CREATED = "snowpark_session_created"
TYPE_CURSOR_CREATED = "snowpark_cursor_created"
[Review comment — Collaborator] Is the purpose to know whether multiple cursors are used?
[Reply — Author] Yes, this addition is part of collecting telemetry. It will track whether a cursor is used in a new thread.

TYPE_SQL_SIMPLIFIER_ENABLED = "snowpark_sql_simplifier_enabled"
TYPE_CTE_OPTIMIZATION_ENABLED = "snowpark_cte_optimization_enabled"
# telemetry for optimization that eliminates the extra cast expression generated for expressions
@@ -90,6 +92,8 @@ class TelemetryField(Enum):
TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_MESSAGE = (
"temp_table_cleanup_abnormal_exception_message"
)
# multi-threading
THREAD_IDENTIFIER = "thread_ident"


# These DataFrame APIs call other DataFrame APIs
@@ -196,6 +200,7 @@ def wrap(*args, **kwargs):
key.value: value
for key, value in plan.cumulative_node_complexity.items()
}
api_calls[0][TelemetryField.THREAD_IDENTIFIER.value] = threading.get_ident()
except Exception:
pass
args[0]._session._conn._telemetry_client.send_function_usage_telemetry(
@@ -343,6 +348,7 @@ def send_upload_file_perf_telemetry(
TelemetryField.KEY_CATEGORY.value: TelemetryField.PERF_CAT_UPLOAD_FILE.value,
TelemetryField.KEY_FUNC_NAME.value: func_name,
TelemetryField.KEY_DURATION.value: duration,
TelemetryField.THREAD_IDENTIFIER.value: threading.get_ident(),
},
}
self.send(message)
@@ -537,3 +543,15 @@
},
}
self.send(message)

def send_cursor_created_telemetry(self, session_id: int, thread_id: int):
message = {
**self._create_basic_telemetry_data(
TelemetryField.TYPE_CURSOR_CREATED.value
),
TelemetryField.KEY_DATA.value: {
TelemetryField.SESSION_ID.value: session_id,
TelemetryField.THREAD_IDENTIFIER.value: thread_id,
},
}
self.send(message)
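
`send_cursor_created_telemetry` follows the same envelope as the other senders in this file. A minimal sketch of the data block it assembles, using the `snowpark_cursor_created` and `thread_ident` names defined in this diff; the session-id key and the basic-telemetry envelope fields are simplified assumptions:

```python
import threading

def cursor_created_payload(session_id: int) -> dict:
    # Simplified: the real message also merges _create_basic_telemetry_data().
    return {
        "type": "snowpark_cursor_created",
        "data": {
            "session_id": session_id,
            "thread_ident": threading.get_ident(),
        },
    }

print(cursor_created_payload(12345))
```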
22 changes: 15 additions & 7 deletions src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py
@@ -2,6 +2,7 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
import logging
import threading
import weakref
from collections import defaultdict
from typing import TYPE_CHECKING, Dict
@@ -31,9 +32,12 @@ def __init__(self, session: "Session") -> None:
# to its reference count for later temp table management
# this dict will still be maintained even if the cleaner is stopped (`stop()` is called)
self.ref_count_map: Dict[str, int] = defaultdict(int)
# Lock to protect the ref_count_map
self.lock = threading.RLock()

def add(self, table: SnowflakeTable) -> None:
self.ref_count_map[table.name] += 1
with self.lock:
self.ref_count_map[table.name] += 1
# the finalizer will be triggered when it gets garbage collected
# and this table will be dropped finally
_ = weakref.finalize(table, self._delete_ref_count, table.name)
@@ -43,18 +47,20 @@ def _delete_ref_count(self, name: str) -> None: # pragma: no cover
Decrements the reference count of a temporary table,
and if the count reaches zero, puts this table in the queue for cleanup.
"""
self.ref_count_map[name] -= 1
if self.ref_count_map[name] == 0:
with self.lock:
self.ref_count_map[name] -= 1
current_ref_count = self.ref_count_map[name]
if current_ref_count == 0:
if (
self.session.auto_clean_up_temp_table_enabled
# if the session is already closed before garbage collection,
# we have no way to drop the table
and not self.session._conn.is_closed()
):
self.drop_table(name)
elif self.ref_count_map[name] < 0:
elif current_ref_count < 0:
logging.debug(
f"Unexpected reference count {self.ref_count_map[name]} for table {name}"
f"Unexpected reference count {current_ref_count} for table {name}"
)

def drop_table(self, name: str) -> None: # pragma: no cover
@@ -95,9 +101,11 @@ def stop(self) -> None:

@property
def num_temp_tables_created(self) -> int:
return len(self.ref_count_map)
with self.lock:
return len(self.ref_count_map)

@property
def num_temp_tables_cleaned(self) -> int:
# TODO SNOW-1662536: we may need a separate counter for the number of tables cleaned when parameter is enabled
return sum(v == 0 for v in self.ref_count_map.values())
with self.lock:
return sum(v == 0 for v in self.ref_count_map.values())
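
The cleaner's core mechanism — a lock-guarded reference count per table name, decremented by a `weakref.finalize` callback when the table object is garbage collected — can be exercised standalone. A minimal sketch with the actual `DROP TABLE` replaced by a print:

```python
import threading
import weakref
from collections import defaultdict

class Cleaner:
    def __init__(self) -> None:
        self.lock = threading.RLock()
        self.ref_count_map = defaultdict(int)

    def add(self, table: "TempTable") -> None:
        with self.lock:
            self.ref_count_map[table.name] += 1
        # Fires when `table` is garbage collected.
        weakref.finalize(table, self._delete_ref_count, table.name)

    def _delete_ref_count(self, name: str) -> None:
        with self.lock:
            self.ref_count_map[name] -= 1
            count = self.ref_count_map[name]
        if count == 0:
            print(f"DROP TABLE IF EXISTS {name}")  # stands in for drop_table()

class TempTable:
    def __init__(self, name: str) -> None:
        self.name = name

cleaner = Cleaner()
table = TempTable("SNOWPARK_TEMP_TABLE_1")
cleaner.add(table)
del table  # in CPython the finalizer runs immediately here
```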