SNOW-1418523 Make Session thread safe (#2312)
sfc-gh-aalam authored Oct 14, 2024
1 parent fa7c406 commit 7c22750
Showing 24 changed files with 1,033 additions and 229 deletions.
3 changes: 2 additions & 1 deletion .github/pull_request_template.md
@@ -6,7 +6,7 @@ Please answer these questions before creating your pull request. Thanks!

<!---
In this section, please add a Snowflake Jira issue number.
Note that if a corresponding GitHub issue exists, you should still include
the Snowflake Jira issue number. For example, for GitHub issue
https://github.com/snowflakedb/snowpark-python/issues/1400, you should
@@ -24,6 +24,7 @@ Please answer these questions before creating your pull request. Thanks!
- [ ] I am adding new credentials
- [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes.
- [ ] I acknowledge that I have ensured my changes to be thread-safe. Follow the link for more information: [Thread-safe Developer Guidelines](https://docs.google.com/document/d/162d_i4zZ2AfcGRXojj0jByt8EUq-DrSHPPnTa4QvwbA/edit#bookmark=id.e82u4nekq80k)

3. Please describe how your code solves the related issue.

36 changes: 36 additions & 0 deletions .github/workflows/threadsafe-check.yml
@@ -0,0 +1,36 @@
name: Threadsafe Check

on:
  pull_request:
    types: [opened, synchronize, labeled, unlabeled, edited]

jobs:
  check_threadsafety:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Check for modified files
        id: changed_files
        run: |
          echo "changed_files=$(git diff --name-only --diff-filter=ACMRT ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | xargs)" >> $GITHUB_OUTPUT

      - name: Verify threadsafety acknowledgement
        run: |
          CHANGED_FILES="${{ steps.changed_files.outputs.changed_files }}"
          # Check if changed files are in snowpark/_internal, snowpark/mock, or snowpark/*.py. We exclude snowpark/modin in this check.
          if echo "$CHANGED_FILES" | grep -qE '(src/snowflake/snowpark/_internal|src/snowflake/snowpark/mock|src/snowflake/snowpark/[^/]+\.py)'; then
            echo "Checking PR description for thread-safety acknowledgment..."
            if [[ "${{ github.event.pull_request.body }}" != *"[x] I acknowledge that I have ensured my changes to be thread-safe"* ]]; then
              echo "Thread-safety acknowledgment not found in PR description."
              echo "Please acknowledge the threadsafety implications of your changes by adding '[x] I acknowledge that I have ensured my changes to be thread-safe' to the PR description."
              exit 1
            else
              echo "Thread-safety acknowledgment found in PR description."
            fi
          else
            echo "No critical files modified; skipping threadsafety check."
          fi
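
A quick way to sanity-check the path filter is to replicate the grep pattern in Python. The file paths below are hypothetical and only illustrate which areas of the tree trigger the acknowledgment requirement; note that nested packages such as snowpark/modin fall outside all three alternatives:

import re

# Same alternation the workflow passes to grep -E.
CRITICAL = re.compile(
    r"src/snowflake/snowpark/_internal"
    r"|src/snowflake/snowpark/mock"
    r"|src/snowflake/snowpark/[^/]+\.py"
)

# Hypothetical file paths, for illustration only.
for path in [
    "src/snowflake/snowpark/session.py",              # top-level module: checked
    "src/snowflake/snowpark/_internal/telemetry.py",  # _internal: checked
    "src/snowflake/snowpark/mock/_connection.py",     # mock: checked
    "src/snowflake/snowpark/modin/plugin/utils.py",   # modin: skipped
]:
    print(path, "->", "checked" if CRITICAL.search(path) else "skipped")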
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -4,11 +4,14 @@

### Snowpark Python API Updates

- Updated `Session` class to be thread-safe. This allows concurrent dataframe transformations, dataframe actions, UDF and stored procedure registration, and concurrent file uploads (a usage sketch follows this file's diff).

#### New Features

### Snowpark pandas API Updates

#### New Features

- Added numpy compatibility support for `np.float_power`, `np.mod`, `np.remainder`, `np.greater`, `np.greater_equal`, `np.less`, `np.less_equal`, `np.not_equal`, and `np.equal`.

## 1.23.0 (2024-10-09)
@@ -23,7 +26,7 @@
- Added support for file writes. This feature is currently in private preview.
- Added `thread_id` to `QueryRecord` to track the id of the thread submitting the query.
- Added support for `Session.stored_procedure_profiler`.
- Added support for 'Service' domain to `session.lineage.trace` API.
- Added support for 'Service' domain to `session.lineage.trace` API.

#### Improvements

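The thread-safety entry above is the user-facing summary of this PR. As a minimal usage sketch (assuming a configured connection and hypothetical table names; this is an illustration, not code from this change), sharing one Session across worker threads now looks like:

from concurrent.futures import ThreadPoolExecutor

from snowflake.snowpark import Session

# Hypothetical connection parameters -- replace with real credentials.
session = Session.builder.configs(
    {"account": "<account>", "user": "<user>", "password": "<password>"}
).create()

def row_count(table_name: str) -> int:
    # All workers share the single Session; after this change the
    # connection layer hands each thread its own cursor.
    return session.table(table_name).count()

tables = ["T1", "T2", "T3"]  # hypothetical table names
with ThreadPoolExecutor(max_workers=3) as pool:
    print(dict(zip(tables, pool.map(row_count, tables))))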
src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py
@@ -113,17 +113,14 @@ def __init__(
session: Session,
query_generator: QueryGenerator,
logical_plans: List[LogicalPlan],
complexity_bounds: Tuple[int, int],
) -> None:
self.session = session
self._query_generator = query_generator
self.logical_plans = logical_plans
self._parent_map = defaultdict(set)
self.complexity_score_lower_bound = (
session.large_query_breakdown_complexity_bounds[0]
)
self.complexity_score_upper_bound = (
session.large_query_breakdown_complexity_bounds[1]
)
self.complexity_score_lower_bound = complexity_bounds[0]
self.complexity_score_upper_bound = complexity_bounds[1]

def apply(self) -> List[LogicalPlan]:
if is_active_transaction(self.session):
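The constructor change above replaces two reads of `session.large_query_breakdown_complexity_bounds` with an explicit `complexity_bounds` argument, so the bounds are captured once by the caller rather than re-read from a session that another thread may be reconfiguring. A generic sketch of this snapshot-at-construction pattern, with all names and values hypothetical:

import threading
from typing import Tuple

class SharedConfig:
    """Mutable settings shared across threads, guarded by a lock."""

    def __init__(self, bounds: Tuple[int, int]) -> None:
        self._lock = threading.RLock()
        self._bounds = bounds

    @property
    def bounds(self) -> Tuple[int, int]:
        with self._lock:
            return self._bounds  # both values read under one lock acquisition

    @bounds.setter
    def bounds(self, value: Tuple[int, int]) -> None:
        with self._lock:
            self._bounds = value

class Optimizer:
    def __init__(self, bounds: Tuple[int, int]) -> None:
        # Snapshot taken once; a concurrent setter call can no longer
        # make the lower and upper bounds come from different configs.
        self.lower_bound, self.upper_bound = bounds

config = SharedConfig((1_000_000, 12_000_000))
optimizer = Optimizer(config.bounds)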
14 changes: 8 additions & 6 deletions src/snowflake/snowpark/_internal/compiler/plan_compiler.py
@@ -77,6 +77,7 @@ def should_start_query_compilation(self) -> bool:

def compile(self) -> Dict[PlanQueryType, List[Query]]:
if self.should_start_query_compilation():
session = self._plan.session
# preparation for compilation
# 1. make a copy of the original plan
start_time = time.time()
@@ -93,7 +94,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
# 3. apply each optimizations if needed
# CTE optimization
cte_start_time = time.time()
if self._plan.session.cte_optimization_enabled:
if session.cte_optimization_enabled:
repeated_subquery_eliminator = RepeatedSubqueryElimination(
logical_plans, query_generator
)
@@ -112,9 +113,12 @@
plot_plan_if_enabled(plan, f"cte_optimized_plan_{i}")

# Large query breakdown
if self._plan.session.large_query_breakdown_enabled:
if session.large_query_breakdown_enabled:
large_query_breakdown = LargeQueryBreakdown(
self._plan.session, query_generator, logical_plans
session,
query_generator,
logical_plans,
session.large_query_breakdown_complexity_bounds,
)
logical_plans = large_query_breakdown.apply()

@@ -133,7 +137,6 @@
cte_time = cte_end_time - cte_start_time
large_query_breakdown_time = large_query_breakdown_end_time - cte_end_time
total_time = time.time() - start_time
session = self._plan.session
summary_value = {
TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled,
TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled,
@@ -156,8 +159,7 @@
return queries
else:
final_plan = self._plan
if self._plan.session.cte_optimization_enabled:
final_plan = final_plan.replace_repeated_subquery_with_cte()
final_plan = final_plan.replace_repeated_subquery_with_cte()
return {
PlanQueryType.QUERIES: final_plan.queries,
PlanQueryType.POST_ACTIONS: final_plan.post_actions,
63 changes: 40 additions & 23 deletions src/snowflake/snowpark/_internal/server_connection.py
@@ -155,6 +155,8 @@ def __init__(
options: Dict[str, Union[int, str]],
conn: Optional[SnowflakeConnection] = None,
) -> None:
self._lock = threading.RLock()
self._thread_store = threading.local()
self._lower_case_parameters = {k.lower(): v for k, v in options.items()}
self._add_application_parameters()
self._conn = conn if conn else connect(**self._lower_case_parameters)
@@ -171,7 +173,6 @@ def __init__(

if "password" in self._lower_case_parameters:
self._lower_case_parameters["password"] = None
self._cursor = self._conn.cursor()
self._telemetry_client = TelemetryClient(self._conn)
self._query_listener: Set[QueryHistory] = set()
# The session in this case refers to a Snowflake session, not a
@@ -184,6 +185,15 @@
"_skip_upload_on_content_match" in signature.parameters
)

@property
def _cursor(self) -> SnowflakeCursor:
if not hasattr(self._thread_store, "cursor"):
self._thread_store.cursor = self._conn.cursor()
self._telemetry_client.send_cursor_created_telemetry(
self.get_session_id(), threading.get_ident()
)
return self._thread_store.cursor

def _add_application_parameters(self) -> None:
if PARAM_APPLICATION not in self._lower_case_parameters:
# Mirrored from snowflake-connector-python/src/snowflake/connector/connection.py#L295
@@ -211,10 +221,12 @@ def _add_application_parameters(self) -> None:
] = get_version()

def add_query_listener(self, listener: QueryHistory) -> None:
self._query_listener.add(listener)
with self._lock:
self._query_listener.add(listener)

def remove_query_listener(self, listener: QueryHistory) -> None:
self._query_listener.remove(listener)
with self._lock:
self._query_listener.remove(listener)

def close(self) -> None:
if self._conn:
@@ -253,17 +265,21 @@ def _run_new_describe(
) -> Union[List[ResultMetadata], List["ResultMetadataV2"]]:
result_metadata = run_new_describe(cursor, query)

for listener in filter(
lambda listener: hasattr(listener, "include_describe")
and listener.include_describe,
self._query_listener,
):
query_record = QueryRecord(cursor.sfqid, query, True)
if getattr(listener, "include_thread_id", False):
with self._lock:
for listener in filter(
lambda listener: hasattr(listener, "include_describe")
and listener.include_describe,
self._query_listener,
):
thread_id = (
threading.get_ident()
if getattr(listener, "include_thread_id", False)
else None
)
query_record = QueryRecord(
cursor.sfqid, query, True, threading.get_ident()
cursor.sfqid, query, True, thread_id=thread_id
)
listener._add_query(query_record)
listener._add_query(query_record)

return result_metadata

@@ -380,17 +396,18 @@ def upload_stream(
raise ex

def notify_query_listeners(self, query_record: QueryRecord) -> None:
for listener in self._query_listener:
if getattr(listener, "include_thread_id", False):
new_record = QueryRecord(
query_record.query_id,
query_record.sql_text,
query_record.is_describe,
thread_id=threading.get_ident(),
)
listener._add_query(new_record)
else:
listener._add_query(query_record)
with self._lock:
for listener in self._query_listener:
if getattr(listener, "include_thread_id", False):
new_record = QueryRecord(
query_record.query_id,
query_record.sql_text,
query_record.is_describe,
thread_id=threading.get_ident(),
)
listener._add_query(new_record)
else:
listener._add_query(query_record)

def execute_and_notify_query_listener(
self, query: str, **kwargs: Any
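The new `_cursor` property above lazily creates one cursor per thread through `threading.local()`, instead of the single shared cursor the old `__init__` created. A self-contained sketch of that pattern, using a stand-in connection class rather than the real SnowflakeConnection:

import threading

class FakeConnection:
    """Stand-in for SnowflakeConnection, for illustration only."""

    def cursor(self) -> str:
        return f"cursor-for-thread-{threading.get_ident()}"

class ServerConnection:
    def __init__(self) -> None:
        self._conn = FakeConnection()
        self._thread_store = threading.local()

    @property
    def cursor(self) -> str:
        # First access from a thread creates that thread's cursor;
        # later accesses from the same thread reuse it.
        if not hasattr(self._thread_store, "cursor"):
            self._thread_store.cursor = self._conn.cursor()
        return self._thread_store.cursor

conn = ServerConnection()
threads = [
    threading.Thread(target=lambda: print(conn.cursor)) for _ in range(3)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(conn.cursor)  # the main thread gets its own cursor as well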
18 changes: 18 additions & 0 deletions src/snowflake/snowpark/_internal/telemetry.py
@@ -4,6 +4,7 @@
#

import functools
import threading
from enum import Enum, unique
from typing import Any, Dict, List, Optional

@@ -38,6 +39,7 @@ class TelemetryField(Enum):
TYPE_PERFORMANCE_DATA = "snowpark_performance_data"
TYPE_FUNCTION_USAGE = "snowpark_function_usage"
TYPE_SESSION_CREATED = "snowpark_session_created"
TYPE_CURSOR_CREATED = "snowpark_cursor_created"
TYPE_SQL_SIMPLIFIER_ENABLED = "snowpark_sql_simplifier_enabled"
TYPE_CTE_OPTIMIZATION_ENABLED = "snowpark_cte_optimization_enabled"
# telemetry for optimization that eliminates the extra cast expression generated for expressions
@@ -90,6 +92,8 @@ class TelemetryField(Enum):
TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_MESSAGE = (
"temp_table_cleanup_abnormal_exception_message"
)
# multi-threading
THREAD_IDENTIFIER = "thread_ident"


# These DataFrame APIs call other DataFrame APIs
@@ -196,6 +200,7 @@ def wrap(*args, **kwargs):
key.value: value
for key, value in plan.cumulative_node_complexity.items()
}
api_calls[0][TelemetryField.THREAD_IDENTIFIER.value] = threading.get_ident()
except Exception:
pass
args[0]._session._conn._telemetry_client.send_function_usage_telemetry(
@@ -343,6 +348,7 @@ def send_upload_file_perf_telemetry(
TelemetryField.KEY_CATEGORY.value: TelemetryField.PERF_CAT_UPLOAD_FILE.value,
TelemetryField.KEY_FUNC_NAME.value: func_name,
TelemetryField.KEY_DURATION.value: duration,
TelemetryField.THREAD_IDENTIFIER.value: threading.get_ident(),
},
}
self.send(message)
@@ -537,3 +543,15 @@ def send_large_query_breakdown_update_complexity_bounds(
},
}
self.send(message)

def send_cursor_created_telemetry(self, session_id: int, thread_id: int):
message = {
**self._create_basic_telemetry_data(
TelemetryField.TYPE_CURSOR_CREATED.value
),
TelemetryField.KEY_DATA.value: {
TelemetryField.SESSION_ID.value: session_id,
TelemetryField.THREAD_IDENTIFIER.value: thread_id,
},
}
self.send(message)
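
Combining the enum values added above (`snowpark_cursor_created`, `thread_ident`) with the method body, the message it sends has roughly the following shape. The envelope produced by `_create_basic_telemetry_data` and the exact `data`/`session_id` key strings are assumptions, since they are not shown in this diff:

# Approximate payload shape only -- keys marked "assumed" are not in this diff.
message = {
    "type": "snowpark_cursor_created",    # TelemetryField.TYPE_CURSOR_CREATED
    "data": {                             # TelemetryField.KEY_DATA (assumed)
        "session_id": 1234567890,         # TelemetryField.SESSION_ID (assumed key, hypothetical value)
        "thread_ident": 140171309479680,  # TelemetryField.THREAD_IDENTIFIER (from this diff)
    },
}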
22 changes: 15 additions & 7 deletions src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py
@@ -2,6 +2,7 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
import logging
import threading
import weakref
from collections import defaultdict
from typing import TYPE_CHECKING, Dict
@@ -31,9 +32,12 @@ def __init__(self, session: "Session") -> None:
# to its reference count for later temp table management
# this dict will still be maintained even if the cleaner is stopped (`stop()` is called)
self.ref_count_map: Dict[str, int] = defaultdict(int)
# Lock to protect the ref_count_map
self.lock = threading.RLock()

def add(self, table: SnowflakeTable) -> None:
self.ref_count_map[table.name] += 1
with self.lock:
self.ref_count_map[table.name] += 1
# the finalizer will be triggered when it gets garbage collected
# and this table will be dropped finally
_ = weakref.finalize(table, self._delete_ref_count, table.name)
@@ -43,18 +47,20 @@ def _delete_ref_count(self, name: str) -> None:  # pragma: no cover
Decrements the reference count of a temporary table,
and if the count reaches zero, puts this table in the queue for cleanup.
"""
self.ref_count_map[name] -= 1
if self.ref_count_map[name] == 0:
with self.lock:
self.ref_count_map[name] -= 1
current_ref_count = self.ref_count_map[name]
if current_ref_count == 0:
if (
self.session.auto_clean_up_temp_table_enabled
# if the session is already closed before garbage collection,
# we have no way to drop the table
and not self.session._conn.is_closed()
):
self.drop_table(name)
elif self.ref_count_map[name] < 0:
elif current_ref_count < 0:
logging.debug(
f"Unexpected reference count {self.ref_count_map[name]} for table {name}"
f"Unexpected reference count {current_ref_count} for table {name}"
)

def drop_table(self, name: str) -> None: # pragma: no cover
@@ -95,9 +101,11 @@ def stop(self) -> None:

@property
def num_temp_tables_created(self) -> int:
return len(self.ref_count_map)
with self.lock:
return len(self.ref_count_map)

@property
def num_temp_tables_cleaned(self) -> int:
# TODO SNOW-1662536: we may need a separate counter for the number of tables cleaned when parameter is enabled
return sum(v == 0 for v in self.ref_count_map.values())
with self.lock:
return sum(v == 0 for v in self.ref_count_map.values())
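
The cleaner pairs the now lock-guarded reference-count map with `weakref.finalize`, so a temp table is only dropped once the last object referencing it is garbage-collected. A minimal, self-contained sketch of that mechanism, with no Snowflake calls and hypothetical names:

import threading
import weakref
from collections import defaultdict

class Table:
    def __init__(self, name: str) -> None:
        self.name = name

class Cleaner:
    def __init__(self) -> None:
        self.ref_count_map = defaultdict(int)
        self.lock = threading.RLock()

    def add(self, table: Table) -> None:
        with self.lock:
            self.ref_count_map[table.name] += 1
        # The callback fires when `table` is garbage-collected.
        weakref.finalize(table, self._delete_ref_count, table.name)

    def _delete_ref_count(self, name: str) -> None:
        with self.lock:
            self.ref_count_map[name] -= 1
            remaining = self.ref_count_map[name]
        if remaining == 0:
            print(f"DROP TABLE IF EXISTS {name}")  # stand-in for drop_table()

cleaner = Cleaner()
t = Table("SNOWPARK_TEMP_TABLE_1")
cleaner.add(t)
del t  # last reference gone: on CPython the finalizer runs and "drops" the table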