Commit 107d246

KuuCi and v-chen_data authored
Insufficient Permissions Error when trying to access table (mosaicml#1555)
Co-authored-by: v-chen_data <v-chen_data@example.com>
1 parent ee45600 commit 107d246

File tree

4 files changed: +103 −99 lines changed

llmfoundry/command_utils/data_prep/convert_delta_to_json.py

Lines changed: 51 additions & 76 deletions
@@ -234,27 +234,7 @@ def run_query(
     elif method == 'dbconnect':
         if spark == None:
             raise ValueError(f'sparkSession is required for dbconnect')
-
-        try:
-            df = spark.sql(query)
-        except Exception as e:
-            from pyspark.errors import AnalysisException
-            if isinstance(e, AnalysisException):
-                if 'INSUFFICIENT_PERMISSIONS' in e.message:  # pyright: ignore
-                    match = re.search(
-                        r"Schema\s+'([^']+)'",
-                        e.message,  # pyright: ignore
-                    )
-                    if match:
-                        schema_name = match.group(1)
-                        action = f'using the schema {schema_name}'
-                    else:
-                        action = 'using the schema'
-                    raise InsufficientPermissionsError(action=action,) from e
-            raise RuntimeError(
-                f'Error in querying into schema. Restart sparkSession and try again',
-            ) from e
+        df = spark.sql(query)
         if collect:
             return df.collect()
         return df
@@ -469,71 +449,66 @@ def fetch(
     """
     cursor = dbsql.cursor() if dbsql is not None else None
     try:
-        nrows = get_total_rows(
-            tablename,
-            method,
-            cursor,
-            sparkSession,
-        )
-    except Exception as e:
-        from pyspark.errors import AnalysisException
-        if isinstance(e, AnalysisException):
-            if 'INSUFFICIENT_PERMISSIONS' in e.message:  # pyright: ignore
-                raise InsufficientPermissionsError(
-                    action=f'reading from {tablename}',
-                ) from e
-        if isinstance(e, InsufficientPermissionsError):
-            raise e
-        raise RuntimeError(
-            f'Error in get rows from {tablename}. Restart sparkSession and try again',
-        ) from e
+        # Get total rows
+        nrows = get_total_rows(tablename, method, cursor, sparkSession)
 
-    try:
+        # Get columns info
         columns, order_by, columns_str = get_columns_info(
             tablename,
             method,
             cursor,
             sparkSession,
         )
+
+        if method == 'dbconnect' and sparkSession is not None:
+            log.info(f'{processes=}')
+            df = sparkSession.table(tablename)
+
+            # Running the query and collecting the data as arrow or json.
+            signed, _, _ = df.collect_cf('arrow')  # pyright: ignore
+            log.info(f'len(signed) = {len(signed)}')
+
+            args = get_args(signed, json_output_folder, columns)
+
+            # Stopping the SparkSession to avoid spilling connection state into the subprocesses.
+            sparkSession.stop()
+
+            with ProcessPoolExecutor(max_workers=processes) as executor:
+                list(executor.map(download_starargs, args))
+
+        elif method == 'dbsql' and cursor is not None:
+            for start in range(0, nrows, batch_size):
+                log.warning(f'batch {start}')
+                end = min(start + batch_size, nrows)
+                fetch_data(
+                    method,
+                    cursor,
+                    sparkSession,
+                    start,
+                    end,
+                    order_by,
+                    tablename,
+                    columns_str,
+                    json_output_folder,
+                )
+
     except Exception as e:
-        raise RuntimeError(
-            f'Error in get columns from {tablename}. Restart sparkSession and try again',
-        ) from e
+        from databricks.sql.exc import ServerOperationError
+        from pyspark.errors import AnalysisException
 
-    if method == 'dbconnect' and sparkSession is not None:
-        log.info(f'{processes=}')
-        df = sparkSession.table(tablename)
-
-        # Running the query and collecting the data as arrow or json.
-        signed, _, _ = df.collect_cf('arrow')  # pyright: ignore
-        log.info(f'len(signed) = {len(signed)}')
-
-        args = get_args(signed, json_output_folder, columns)
-
-        # Stopping the SparkSession to avoid spilling connection state into the subprocesses.
-        sparkSession.stop()
-
-        with ProcessPoolExecutor(max_workers=processes) as executor:
-            list(executor.map(download_starargs, args))
-
-    elif method == 'dbsql' and cursor is not None:
-        for start in range(0, nrows, batch_size):
-            log.warning(f'batch {start}')
-            end = min(start + batch_size, nrows)
-            fetch_data(
-                method,
-                cursor,
-                sparkSession,
-                start,
-                end,
-                order_by,
-                tablename,
-                columns_str,
-                json_output_folder,
-            )
+        if isinstance(e, (AnalysisException, ServerOperationError)):
+            if 'INSUFFICIENT_PERMISSIONS' in str(e):
+                raise InsufficientPermissionsError(str(e)) from e
+
+        if isinstance(e, InsufficientPermissionsError):
+            raise
+
+        # For any other exception, raise a general error
+        raise RuntimeError(f'Error processing {tablename}: {str(e)}') from e
 
-    if cursor is not None:
-        cursor.close()
+    finally:
+        if cursor is not None:
+            cursor.close()
 
 
 def validate_and_get_cluster_info(
llmfoundry/utils/exceptions.py

Lines changed: 10 additions & 3 deletions
@@ -456,6 +456,13 @@ def __init__(
 class InsufficientPermissionsError(UserError):
     """Error thrown when the user does not have sufficient permissions."""
 
-    def __init__(self, action: str) -> None:
-        message = f'Insufficient permissions when {action}. Please check your permissions.'
-        super().__init__(message, action=action)
+    def __init__(self, message: str) -> None:
+        self.message = message
+        super().__init__(message)
+
+    def __reduce__(self):
+        # Return a tuple of class, a tuple of arguments, and optionally state
+        return (InsufficientPermissionsError, (self.message,))
+
+    def __str__(self):
+        return self.message
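
The __reduce__ override matters because fetch fans work out to a ProcessPoolExecutor: an exception that crosses a process boundary must survive a pickle round trip, and pickle rebuilds an instance from the (class, args) tuple that __reduce__ returns. A small self-contained check of that behavior, assuming llmfoundry is installed:

    import pickle

    from llmfoundry.utils.exceptions import InsufficientPermissionsError

    err = InsufficientPermissionsError('INSUFFICIENT_PERMISSIONS: cannot read table')
    restored = pickle.loads(pickle.dumps(err))

    # __reduce__ re-invokes the class with (self.message,), so both the
    # attribute and the string form survive the round trip.
    assert restored.message == err.message
    assert str(restored) == str(err)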

tests/a_scripts/data_prep/test_convert_delta_to_json.py

Lines changed: 15 additions & 8 deletions
@@ -10,6 +10,7 @@
 from llmfoundry.command_utils.data_prep.convert_delta_to_json import (
     InsufficientPermissionsError,
     download,
+    fetch,
     fetch_DT,
     format_tablename,
     iterative_combine_jsons,
@@ -30,27 +31,33 @@ class MockAnalysisException(Exception):
             def __init__(self, message: str):
                 self.message = message
 
+            def __str__(self):
+                return self.message
+
         with patch.dict('sys.modules', {'pyspark.errors': MagicMock()}):
             sys.modules[
                 'pyspark.errors'
-            ].AnalysisException = MockAnalysisException  # pyright: ignore
+            ].AnalysisException = MockAnalysisException  # type: ignore
 
             mock_spark = MagicMock()
             mock_spark.sql.side_effect = MockAnalysisException(error_message)
 
             with self.assertRaises(InsufficientPermissionsError) as context:
-                run_query(
-                    'SELECT * FROM table',
+                fetch(
                     method='dbconnect',
-                    cursor=None,
-                    spark=mock_spark,
+                    tablename='main.oogabooga',
+                    json_output_folder='/fake/path',
+                    batch_size=1,
+                    processes=1,
+                    sparkSession=mock_spark,
+                    dbsql=None,
                 )
 
-            self.assertIn(
-                'using the schema main.oogabooga',
+            self.assertEqual(
                 str(context.exception),
+                error_message,
             )
-            mock_spark.sql.assert_called_once_with('SELECT * FROM table')
+            mock_spark.sql.assert_called()
 
     @patch(
         'databricks.sql.connect',
tests/utils/test_exceptions.py

Lines changed: 27 additions & 12 deletions
@@ -4,7 +4,7 @@
 import contextlib
 import inspect
 import pickle
-from typing import Any, Optional
+from typing import Any, Optional, get_type_hints
 
 import pytest
 
@@ -14,16 +14,30 @@
 def create_exception_object(
     exception_class: type[foundry_exceptions.BaseContextualError],
 ):
-    # get required arg types of exception class by inspecting its __init__ method
 
-    if hasattr(inspect, 'get_annotations'):
-        required_args = inspect.get_annotations(  # type: ignore
-            exception_class.__init__,
-        )  # type: ignore
-    else:
-        required_args = exception_class.__init__.__annotations__  # python 3.9 and below
-
-    # create a dictionary of required args with default values
+    def get_init_annotations(cls: type):
+        try:
+            return get_type_hints(cls.__init__)
+        except (AttributeError, TypeError):
+            # Handle cases where __init__ does not exist or has no annotations
+            return {}
+
+    # First, try to get annotations from the class itself
+    required_args = get_init_annotations(exception_class)
+
+    # If the annotations are empty, look at parent classes
+    if not required_args:
+        for parent in exception_class.__bases__:
+            if parent == object:
+                break
+            parent_args = get_init_annotations(parent)
+            if parent_args:
+                required_args = parent_args
+                break
+
+    # Remove self, return, and kwargs
+    required_args.pop('self', None)
+    required_args.pop('return', None)
     required_args.pop('kwargs', None)
 
     def get_default_value(arg_type: Optional[type] = None):
@@ -51,8 +65,6 @@ def get_default_value(arg_type: Optional[type] = None):
             return [{'key': 'value'}]
         raise ValueError(f'Unsupported arg type: {arg_type}')
 
-    required_args.pop('self', None)
-    required_args.pop('return', None)
     kwargs = {
         arg: get_default_value(arg_type)
         for arg, arg_type in required_args.items()
@@ -80,6 +92,7 @@ def filter_exceptions(possible_exceptions: list[str]):
 def test_exception_serialization(
     exception_class: type[foundry_exceptions.BaseContextualError],
 ):
+    print(f'Testing serialization for {exception_class.__name__}')
     excluded_base_classes = [
         foundry_exceptions.InternalError,
         foundry_exceptions.UserError,
@@ -88,13 +101,15 @@
     ]
 
     exception = create_exception_object(exception_class)
+    print(f'Created exception object: {exception}')
 
     expect_reduce_error = exception.__class__ in excluded_base_classes
     error_context = pytest.raises(
         NotImplementedError,
     ) if expect_reduce_error else contextlib.nullcontext()
 
     exc_str = str(exception)
+    print(f'Exception string: {exc_str}')
     with error_context:
         pkl = pickle.dumps(exception)
         unpickled_exc = pickle.loads(pkl)
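
The move from inspect.get_annotations to typing.get_type_hints is what lets the test see real type objects: get_type_hints evaluates stringified annotations (e.g. under from __future__ import annotations) that the raw __annotations__ dict leaves as strings, and get_default_value needs actual types to pick values. A quick illustration with a hypothetical class, not from the repo:

    from __future__ import annotations  # annotations become plain strings at runtime

    from typing import get_type_hints

    class DemoError(Exception):
        def __init__(self, message: str) -> None:
            super().__init__(message)

    print(DemoError.__init__.__annotations__)
    # {'message': 'str', 'return': 'None'} - raw, unevaluated strings
    print(get_type_hints(DemoError.__init__))
    # {'message': <class 'str'>, 'return': <class 'NoneType'>}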
