1+ """
2+ Unit test to validate that all Thrift-generated field IDs comply with the maximum limit.
3+
4+ Field IDs in Thrift must stay below 3329 to avoid conflicts with reserved ranges
5+ and ensure compatibility with various Thrift implementations and protocols.
6+ """
7+
8+ import inspect
9+ import pytest
10+ import unittest
11+
12+ from databricks .sql .thrift_api .TCLIService import ttypes
13+
14+
15+ class TestThriftFieldIds (unittest .TestCase ):
16+ """Test suite for validating Thrift field ID constraints."""
17+
18+ MAX_ALLOWED_FIELD_ID = 3329
19+
20+ # Known exceptions that exceed the field ID limit
21+ KNOWN_EXCEPTIONS = {
22+ ('TExecuteStatementReq' , 'enforceEmbeddedSchemaCorrectness' ): 3353 ,
23+ ('TSessionHandle' , 'serverProtocolVersion' ): 3329 ,
24+ }
25+
26+ def test_all_thrift_field_ids_are_within_allowed_range (self ):
27+ """
28+ Validates that all field IDs in Thrift-generated classes are within the allowed range.
29+
30+ This test prevents field ID conflicts and ensures compatibility with different
31+ Thrift implementations and protocols.
32+ """
33+ violations = []
34+
35+ # Get all classes from the ttypes module
36+ for name , obj in inspect .getmembers (ttypes ):
37+ if (inspect .isclass (obj ) and
38+ hasattr (obj , 'thrift_spec' ) and
39+ obj .thrift_spec is not None ):
40+
41+ self ._check_class_field_ids (obj , name , violations )
42+
43+ if violations :
44+ error_message = self ._build_error_message (violations )
45+ self .fail (error_message )
46+
47+ def _check_class_field_ids (self , cls , class_name , violations ):
48+ """
49+ Checks all field IDs in a Thrift class and reports violations.
50+
51+ Args:
52+ cls: The Thrift class to check
53+ class_name: Name of the class for error reporting
54+ violations: List to append violation messages to
55+ """
56+ thrift_spec = cls .thrift_spec
57+
58+ if not isinstance (thrift_spec , (tuple , list )):
59+ return
60+
61+ for spec_entry in thrift_spec :
62+ if spec_entry is None :
63+ continue
64+
65+ # Thrift spec format: (field_id, field_type, field_name, ...)
66+ if isinstance (spec_entry , (tuple , list )) and len (spec_entry ) >= 3 :
67+ field_id = spec_entry [0 ]
68+ field_name = spec_entry [2 ]
69+
70+ # Skip known exceptions
71+ if (class_name , field_name ) in self .KNOWN_EXCEPTIONS :
72+ continue
73+
74+ if isinstance (field_id , int ) and field_id >= self .MAX_ALLOWED_FIELD_ID :
75+ violations .append (
76+ "{} field '{}' has field ID {} (exceeds maximum of {})" .format (
77+ class_name , field_name , field_id , self .MAX_ALLOWED_FIELD_ID - 1
78+ )
79+ )
80+
81+ def _build_error_message (self , violations ):
82+ """
83+ Builds a comprehensive error message for field ID violations.
84+
85+ Args:
86+ violations: List of violation messages
87+
88+ Returns:
89+ Formatted error message
90+ """
91+ error_message = (
92+ "Found Thrift field IDs that exceed the maximum allowed value of {}.\n "
93+ "This can cause compatibility issues and conflicts with reserved ID ranges.\n "
94+ "Violations found:\n " .format (self .MAX_ALLOWED_FIELD_ID - 1 )
95+ )
96+
97+ for violation in violations :
98+ error_message += " - {}\n " .format (violation )
99+
100+ return error_message
0 commit comments