diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 765d59cf9..dd5bd6c20 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.11"] # 3.x disabled b/c of 3.12 test failures w/ GRPC. + python-version: ["3.12"] # 3.x disabled b/c of 3.13 test failures w/ JAX. suffix: ["core", "benchmarks", "algorithms", "clients", "pyglove", "raytune"] include: - suffix: "clients" diff --git a/.github/workflows/pypi-publish-dev.yml b/.github/workflows/pypi-publish-dev.yml index 5c6f21088..8315fca9c 100644 --- a/.github/workflows/pypi-publish-dev.yml +++ b/.github/workflows/pypi-publish-dev.yml @@ -17,7 +17,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: '3.11' + python-version: '3.12' - name: Install dependencies # NOTE: grpcio-tools needs to be periodically updated to support later Python versions. run: | diff --git a/.github/workflows/pypi-publish.yml b/.github/workflows/pypi-publish.yml index 0996b2f7c..ea9a84ee3 100644 --- a/.github/workflows/pypi-publish.yml +++ b/.github/workflows/pypi-publish.yml @@ -15,7 +15,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: '3.11' + python-version: '3.12' - name: Install dependencies # NOTE: grpcio-tools needs to be periodically updated to support later Python versions. run: | diff --git a/vizier/__init__.py b/vizier/__init__.py index a57199729..add6643eb 100644 --- a/vizier/__init__.py +++ b/vizier/__init__.py @@ -23,4 +23,4 @@ sys.path.append(PROTO_ROOT) -__version__ = "0.1.23" +__version__ = "0.1.24" diff --git a/vizier/_src/service/performance_test.py b/vizier/_src/service/performance_test.py index f4ae13301..af0d351a9 100644 --- a/vizier/_src/service/performance_test.py +++ b/vizier/_src/service/performance_test.py @@ -18,8 +18,8 @@ import multiprocessing.pool import time -from absl import logging +from absl import logging from vizier._src.service import constants from vizier._src.service import vizier_client from vizier._src.service import vizier_server @@ -41,23 +41,15 @@ def setUpClass(cls): ) vizier_client.environment_variables.server_endpoint = cls.server.endpoint - @parameterized.parameters( - (1, 10, 2), - (2, 10, 2), - (10, 10, 2), - (50, 5, 2), - (100, 5, 2), - ) + @parameterized.parameters((1, 10), (2, 10), (10, 10), (50, 5), (100, 5)) def test_multiple_clients_basic( - self, num_simultaneous_clients, num_trials_per_client, dimension + self, num_simultaneous_clients, num_trials_per_client ): def fn(client_id: int): - experimenter = experimenters.BBOBExperimenterFactory( - 'Sphere', dimension - )() + experimenter = experimenters.BBOBExperimenterFactory('Sphere', 2)() problem_statement = experimenter.problem_statement() study_config = pyvizier.StudyConfig.from_problem(problem_statement) - study_config.algorithm = pyvizier.Algorithm.NSGA2 + study_config.algorithm = pyvizier.Algorithm.RANDOM_SEARCH client = vizier_client.create_or_load_study( owner_id='my_username', diff --git a/vizier/_src/service/ram_datastore.py b/vizier/_src/service/ram_datastore.py index 4048a5eb7..b603867d3 100644 --- a/vizier/_src/service/ram_datastore.py +++ b/vizier/_src/service/ram_datastore.py @@ -18,13 +18,14 @@ For debugging/testing purposes mainly. """ + import collections import copy import dataclasses import threading from typing import Callable, DefaultDict, Dict, Iterable, List, Optional -from absl import logging +from absl import logging from vizier._src.service import custom_errors from vizier._src.service import datastore from vizier._src.service import key_value_pb2 @@ -109,53 +110,55 @@ def create_study(self, study: study_pb2.Study) -> resources.StudyResource: def load_study(self, study_name: str) -> study_pb2.Study: resource = resources.StudyResource.from_name(study_name) - try: - with self._lock: + + with self._lock: + try: return copy.deepcopy( self._owners[resource.owner_id] .studies[resource.study_id] .study_proto ) - except KeyError as err: - raise custom_errors.NotFoundError( - 'Could not get Study with name:', resource.name - ) from err + except KeyError as err: + raise custom_errors.NotFoundError( + 'Could not get Study with name:', resource.name + ) from err def update_study(self, study: study_pb2.Study) -> resources.StudyResource: resource = resources.StudyResource.from_name(study.name) - try: - with self._lock: + with self._lock: + try: self._owners[resource.owner_id].studies[ resource.study_id ].study_proto.CopyFrom(study) - return resource - except KeyError as err: - raise custom_errors.NotFoundError( - 'Could not update Study with name:', resource.name - ) from err + return resource + except KeyError as err: + raise custom_errors.NotFoundError( + 'Could not update Study with name:', resource.name + ) from err def delete_study(self, study_name: str) -> None: resource = resources.StudyResource.from_name(study_name) - try: - with self._lock: + with self._lock: + try: del self._owners[resource.owner_id].studies[resource.study_id] - except KeyError as err: - raise custom_errors.NotFoundError( - 'Study does not exist:', study_name - ) from err + except KeyError as err: + raise custom_errors.NotFoundError( + 'Study does not exist:', study_name + ) from err def list_studies(self, owner_name: str) -> List[study_pb2.Study]: resource = resources.OwnerResource.from_name(owner_name) - try: - with self._lock: + + with self._lock: + try: study_nodes = list(self._owners[resource.owner_id].studies.values()) return copy.deepcopy( [study_node.study_proto for study_node in study_nodes] ) - except KeyError as err: - raise custom_errors.NotFoundError( - 'Owner does not exist:', owner_name - ) from err + except KeyError as err: + raise custom_errors.NotFoundError( + 'Owner does not exist:', owner_name + ) from err def create_trial(self, trial: study_pb2.Trial) -> resources.TrialResource: resource = resources.TrialResource.from_name(trial.name) @@ -175,22 +178,23 @@ def create_trial(self, trial: study_pb2.Trial) -> resources.TrialResource: def get_trial(self, trial_name: str) -> study_pb2.Trial: resource = resources.TrialResource.from_name(trial_name) - try: - with self._lock: + + with self._lock: + try: return copy.deepcopy( self._owners[resource.owner_id] .studies[resource.study_id] .trial_protos[resource.trial_id] ) - except KeyError as err: - raise custom_errors.NotFoundError( - 'Could not get Trial with name:', resource.name - ) from err + except KeyError as err: + raise custom_errors.NotFoundError( + 'Could not get Trial with name:', resource.name + ) from err def update_trial(self, trial: study_pb2.Trial) -> resources.TrialResource: resource = resources.TrialResource.from_name(trial.name) - try: - with self._lock: + with self._lock: + try: trial_protos = ( self._owners[resource.owner_id] .studies[resource.study_id] @@ -201,16 +205,17 @@ def update_trial(self, trial: study_pb2.Trial) -> resources.TrialResource: 'Trial %s does not exist.' % trial.name ) trial_protos[resource.trial_id] = copy.deepcopy(trial) - return resource - except KeyError as err: - raise custom_errors.NotFoundError( - 'Could not update Trial with name:', resource.name - ) from err + except KeyError as err: + raise custom_errors.NotFoundError( + 'Could not update Trial with name:', resource.name + ) from err + + return resource def list_trials(self, study_name: str) -> List[study_pb2.Trial]: resource = resources.StudyResource.from_name(study_name) - try: - with self._lock: + with self._lock: + try: return copy.deepcopy( list( self._owners[resource.owner_id] @@ -218,29 +223,29 @@ def list_trials(self, study_name: str) -> List[study_pb2.Trial]: .trial_protos.values() ) ) - except KeyError as err: - raise custom_errors.NotFoundError( - 'Study does not exist:', study_name - ) from err + except KeyError as err: + raise custom_errors.NotFoundError( + 'Study does not exist:', study_name + ) from err def delete_trial(self, trial_name: str) -> None: resource = resources.TrialResource.from_name(trial_name) - try: - with self._lock: + with self._lock: + try: del ( self._owners[resource.owner_id] .studies[resource.study_id] .trial_protos[resource.trial_id] ) - except KeyError as err: - raise custom_errors.NotFoundError( - 'Trial does not exist:', trial_name - ) from err + except KeyError as err: + raise custom_errors.NotFoundError( + 'Trial does not exist:', trial_name + ) from err def max_trial_id(self, study_name: str) -> int: resource = resources.StudyResource.from_name(study_name) - try: - with self._lock: + with self._lock: + try: trial_ids = copy.deepcopy( list( self._owners[resource.owner_id] @@ -248,15 +253,12 @@ def max_trial_id(self, study_name: str) -> int: .trial_protos.keys() ) ) - except KeyError as err: - raise custom_errors.NotFoundError( - 'Study does not exist:', study_name - ) from err + except KeyError as err: + raise custom_errors.NotFoundError( + 'Study does not exist:', study_name + ) from err - if trial_ids: - return max(trial_ids) - else: - return 0 + return max(trial_ids) if trial_ids else 0 def create_suggestion_operation( self, operation: operations_pb2.Operation @@ -291,8 +293,9 @@ def get_suggestion_operation( self, operation_name: str ) -> operations_pb2.Operation: resource = resources.SuggestionOperationResource.from_name(operation_name) - try: - with self._lock: + + with self._lock: + try: return copy.deepcopy( self._owners[resource.owner_id] .studies[resource.study_id] @@ -300,27 +303,29 @@ def get_suggestion_operation( .suggestion_operations[resource.operation_id] ) - except KeyError as err: - raise custom_errors.NotFoundError( - 'Could not find SuggestionOperation with name:', resource.name - ) from err + except KeyError as err: + raise custom_errors.NotFoundError( + 'Could not find SuggestionOperation with name:', resource.name + ) from err def update_suggestion_operation( self, operation: operations_pb2.Operation ) -> resources.SuggestionOperationResource: resource = resources.SuggestionOperationResource.from_name(operation.name) - try: - with self._lock: + + with self._lock: + try: self._owners[resource.owner_id].studies[resource.study_id].clients[ resource.client_id ].suggestion_operations[resource.operation_id] = copy.deepcopy( operation ) + except KeyError as err: + raise custom_errors.NotFoundError( + 'Could not update SuggestionOperation with name:', resource.name + ) from err + return resource - except KeyError as err: - raise custom_errors.NotFoundError( - 'Could not update SuggestionOperation with name:', resource.name - ) from err def list_suggestion_operations( self, @@ -329,8 +334,9 @@ def list_suggestion_operations( filter_fn: Optional[Callable[[operations_pb2.Operation], bool]] = None, ) -> List[operations_pb2.Operation]: resource = resources.StudyResource.from_name(study_name) - try: - with self._lock: + + with self._lock: + try: operations_list = copy.deepcopy( list( self._owners[resource.owner_id] @@ -339,22 +345,23 @@ def list_suggestion_operations( .suggestion_operations.values() ) ) - except KeyError as err: - raise custom_errors.NotFoundError( - '(study_name, client_id) does not exist:', (study_name, client_id) - ) from err + except KeyError as err: + raise custom_errors.NotFoundError( + '(study_name, client_id) does not exist:', (study_name, client_id) + ) from err - if filter_fn is not None: - return copy.deepcopy([op for op in operations_list if filter_fn(op)]) - else: - return copy.deepcopy(operations_list) + if filter_fn is not None: + return copy.deepcopy([op for op in operations_list if filter_fn(op)]) + else: + return copy.deepcopy(operations_list) def max_suggestion_operation_number( self, study_name: str, client_id: str ) -> int: resource = resources.StudyResource.from_name(study_name) - try: - with self._lock: + + with self._lock: + try: ops = ( self._owners[resource.owner_id] .studies[resource.study_id] @@ -362,10 +369,10 @@ def max_suggestion_operation_number( .suggestion_operations ) return len(ops) - except KeyError as err: - raise custom_errors.NotFoundError( - '(study_name, client_id) does not exist:', (study_name, client_id) - ) from err + except KeyError as err: + raise custom_errors.NotFoundError( + '(study_name, client_id) does not exist:', (study_name, client_id) + ) from err def create_early_stopping_operation( self, operation: vizier_oss_pb2.EarlyStoppingOperation @@ -393,17 +400,17 @@ def get_early_stopping_operation( resource = resources.EarlyStoppingOperationResource.from_name( operation_name ) - try: - with self._lock: + with self._lock: + try: return copy.deepcopy( self._owners[resource.owner_id] .studies[resource.study_id] .early_stopping_operations[resource.operation_id] ) - except KeyError as err: - raise custom_errors.NotFoundError( - 'Could not find EarlyStoppingOperation with name:', resource.name - ) from err + except KeyError as err: + raise custom_errors.NotFoundError( + 'Could not find EarlyStoppingOperation with name:', resource.name + ) from err def update_early_stopping_operation( self, operation: vizier_oss_pb2.EarlyStoppingOperation @@ -411,18 +418,18 @@ def update_early_stopping_operation( resource = resources.EarlyStoppingOperationResource.from_name( operation.name ) - try: - with self._lock: + with self._lock: + try: self._owners[resource.owner_id].studies[ resource.study_id ].early_stopping_operations[resource.operation_id] = copy.deepcopy( operation ) - return resource - except KeyError as err: - raise custom_errors.NotFoundError( - 'Could not update EarlyStoppingOperation with name:', resource.name - ) from err + except KeyError as err: + raise custom_errors.NotFoundError( + 'Could not update EarlyStoppingOperation with name:', resource.name + ) from err + return resource def update_metadata( self, diff --git a/vizier/_src/service/sql_datastore.py b/vizier/_src/service/sql_datastore.py index 094282d59..7f14bd84c 100644 --- a/vizier/_src/service/sql_datastore.py +++ b/vizier/_src/service/sql_datastore.py @@ -103,32 +103,31 @@ def create_study(self, study: study_pb2.Study) -> resources.StudyResource: with self._lock: try: - self._connection.execute(owner_query) - self._connection.commit() + self._write_or_rollback(owner_query) except sqla.exc.IntegrityError: logging.info('Owner with name %s currently exists.', owner_name) - self._connection.rollback() + try: - self._connection.execute(study_query) - self._connection.commit() - return study_resource + self._write_or_rollback(study_query) except sqla.exc.IntegrityError as e: - self._connection.rollback() raise AlreadyExistsError( 'Study with name %s already exists.' % study.name ) from e + self._connection.commit() + + return study_resource def load_study(self, study_name: str) -> study_pb2.Study: query = sqla.select(self._studies_table) query = query.where(self._studies_table.c.study_name == study_name) with self._lock: - result = self._connection.execute(query) + row = self._connection.execute(query).fetchone() + if not row: + raise NotFoundError('Failed to find study name: %s' % study_name) + study = study_pb2.Study.FromString(row.serialized_study) - row = result.fetchone() - if not row: - raise NotFoundError('Failed to find study name: %s' % study_name) - return study_pb2.Study.FromString(row.serialized_study) + return study def update_study(self, study: study_pb2.Study) -> resources.StudyResource: study_resource = resources.StudyResource.from_name(study.name) @@ -151,8 +150,9 @@ def update_study(self, study: study_pb2.Study) -> resources.StudyResource: with self._lock: if not self._connection.execute(eq).fetchone()[0]: raise NotFoundError('Study %s does not exist.' % study.name) - self._connection.execute(uq) + self._write_or_rollback(uq) self._connection.commit() + return study_resource def delete_study(self, study_name: str) -> None: @@ -175,8 +175,8 @@ def delete_study(self, study_name: str) -> None: with self._lock: if not self._connection.execute(eq).fetchone()[0]: raise NotFoundError('Study %s does not exist.' % study_name) - self._connection.execute(dsq) - self._connection.execute(dtq) + self._write_or_rollback(dsq) + self._write_or_rollback(dtq) self._connection.commit() def list_studies(self, owner_name: str) -> List[study_pb2.Study]: @@ -195,8 +195,11 @@ def list_studies(self, owner_name: str) -> List[study_pb2.Study]: if not self._connection.execute(eq).fetchone()[0]: raise NotFoundError('Owner name %s does not exist.' % owner_name) result = self._connection.execute(lq).fetchall() + studies = [ + study_pb2.Study.FromString(row.serialized_study) for row in result + ] - return [study_pb2.Study.FromString(row.serialized_study) for row in result] + return studies def create_trial(self, trial: study_pb2.Trial) -> resources.TrialResource: trial_resource = resources.TrialResource.from_name(trial.name) @@ -210,14 +213,14 @@ def create_trial(self, trial: study_pb2.Trial) -> resources.TrialResource: with self._lock: try: - self._connection.execute(query) - self._connection.commit() - return trial_resource + self._write_or_rollback(query) except sqla.exc.IntegrityError as e: - self._connection.rollback() raise AlreadyExistsError( 'Trial with name %s already exists.' % trial.name ) from e + self._connection.commit() + + return trial_resource def get_trial(self, trial_name: str) -> study_pb2.Trial: query = sqla.select(self._trials_table) @@ -225,11 +228,13 @@ def get_trial(self, trial_name: str) -> study_pb2.Trial: with self._lock: result = self._connection.execute(query) + row = result.fetchone() - row = result.fetchone() - if not row: - raise NotFoundError('Failed to find trial name: %s' % trial_name) - return study_pb2.Trial.FromString(row.serialized_trial) + if not row: + raise NotFoundError('Failed to find trial name: %s' % trial_name) + trial = study_pb2.Trial.FromString(row.serialized_trial) + + return trial def update_trial(self, trial: study_pb2.Trial) -> resources.TrialResource: trial_resource = resources.TrialResource.from_name(trial.name) @@ -253,7 +258,7 @@ def update_trial(self, trial: study_pb2.Trial) -> resources.TrialResource: with self._lock: if not self._connection.execute(eq).fetchone()[0]: raise NotFoundError('Trial %s does not exist.' % trial.name) - self._connection.execute(uq) + self._write_or_rollback(uq) self._connection.commit() return trial_resource @@ -275,8 +280,11 @@ def list_trials(self, study_name: str) -> List[study_pb2.Trial]: if not self._connection.execute(eq).fetchone()[0]: raise NotFoundError('Study name %s does not exist.' % study_name) result = self._connection.execute(lq) + trials = [ + study_pb2.Trial.FromString(row.serialized_trial) for row in result + ] - return [study_pb2.Trial.FromString(row.serialized_trial) for row in result] + return trials def delete_trial(self, trial_name: str) -> None: # Exist query @@ -291,7 +299,7 @@ def delete_trial(self, trial_name: str) -> None: with self._lock: if not self._connection.execute(eq).fetchone()[0]: raise NotFoundError('Trial %s does not exist.' % trial_name) - self._connection.execute(dq) + self._write_or_rollback(dq) self._connection.commit() def max_trial_id(self, study_name: str) -> int: @@ -313,9 +321,7 @@ def max_trial_id(self, study_name: str) -> int: raise NotFoundError('Study %s does not exist.' % study_name) potential_trial_id = self._connection.execute(tq).fetchone()[0] - if potential_trial_id is None: - return 0 - return potential_trial_id + return potential_trial_id if potential_trial_id is not None else 0 def create_suggestion_operation( self, operation: operations_pb2.Operation @@ -330,16 +336,16 @@ def create_suggestion_operation( serialized_op=operation.SerializeToString(), ) - try: - with self._lock: - self._connection.execute(query) - self._connection.commit() - return resource - except sqla.exc.IntegrityError as e: - self._connection.rollback() - raise AlreadyExistsError( - 'Suggest Op with name %s already exists.' % operation.name - ) from e + with self._lock: + try: + self._write_or_rollback(query) + except sqla.exc.IntegrityError as e: + raise AlreadyExistsError( + 'Suggest Op with name %s already exists.' % operation.name + ) from e + self._connection.commit() + + return resource def get_suggestion_operation( self, operation_name: str @@ -350,12 +356,15 @@ def get_suggestion_operation( ) with self._lock: - result = self._connection.execute(q) + row = self._connection.execute(q).fetchone() + + if not row: + raise NotFoundError( + 'Failed to find suggest op name: %s' % operation_name + ) + operation = operations_pb2.Operation.FromString(row.serialized_op) - row = result.fetchone() - if not row: - raise NotFoundError('Failed to find suggest op name: %s' % operation_name) - return operations_pb2.Operation.FromString(row.serialized_op) + return operation def update_suggestion_operation( self, operation: operations_pb2.Operation @@ -386,8 +395,9 @@ def update_suggestion_operation( with self._lock: if not self._connection.execute(eq).fetchone()[0]: raise NotFoundError('Suggest op %s does not exist.' % operation.name) - self._connection.execute(uq) + self._write_or_rollback(uq) self._connection.commit() + return resource def list_suggestion_operations( @@ -414,10 +424,10 @@ def list_suggestion_operations( (study_resource.name, client_id), ) result = self._connection.execute(q) - - all_ops = [ - operations_pb2.Operation.FromString(row.serialized_op) for row in result - ] + all_ops = [ + operations_pb2.Operation.FromString(row.serialized_op) + for row in result + ] if filter_fn is None: return all_ops @@ -458,7 +468,9 @@ def max_suggestion_operation_number( raise NotFoundError( 'Could not find (study_name, client_id):', (study_name, client_id) ) - return self._connection.execute(mq).fetchone()[0] + max_op_number = self._connection.execute(mq).fetchone()[0] + + return max_op_number def create_early_stopping_operation( self, operation: vizier_oss_pb2.EarlyStoppingOperation @@ -474,16 +486,16 @@ def create_early_stopping_operation( serialized_op=operation.SerializeToString(), ) - try: - with self._lock: - self._connection.execute(query) - self._connection.commit() - return resource - except sqla.exc.IntegrityError as e: - self._connection.rollback() - raise AlreadyExistsError( - 'Early stopping op with name %s already exists.' % operation.name - ) from e + with self._lock: + try: + self._write_or_rollback(query) + except sqla.exc.IntegrityError as e: + raise AlreadyExistsError( + 'Early stopping op with name %s already exists.' % operation.name + ) from e + self._connection.commit() + + return resource def get_early_stopping_operation( self, operation_name: str @@ -496,12 +508,16 @@ def get_early_stopping_operation( with self._lock: result = self._connection.execute(q) - row = result.fetchone() - if not row: - raise NotFoundError( - 'Failed to find early stopping op name: %s' % operation_name + row = result.fetchone() + if not row: + raise NotFoundError( + 'Failed to find early stopping op name: %s' % operation_name + ) + operation = vizier_oss_pb2.EarlyStoppingOperation.FromString( + row.serialized_op ) - return vizier_oss_pb2.EarlyStoppingOperation.FromString(row.serialized_op) + + return operation def update_early_stopping_operation( self, operation: vizier_oss_pb2.EarlyStoppingOperation @@ -537,7 +553,8 @@ def update_early_stopping_operation( ) self._connection.execute(uq) self._connection.commit() - return resource + + return resource def update_metadata( self, @@ -553,8 +570,7 @@ def update_metadata( sq = sq.where(self._studies_table.c.study_name == study_name) with self._lock: - study_result = self._connection.execute(sq) - row = study_result.fetchone() + row = self._connection.execute(sq).fetchone() if not row: raise NotFoundError('No such study:', s_resource.name) original_study = study_pb2.Study.FromString(row.serialized_study) @@ -567,8 +583,7 @@ def update_metadata( usq = sqla.update(self._studies_table) usq = usq.where(self._studies_table.c.study_name == study_name) usq = usq.values(serialized_study=original_study.SerializeToString()) - self._connection.execute(usq) - self._connection.commit() + self._write_or_rollback(usq) # Split the trial-related metadata by Trial. split_metadata = collections.defaultdict(list) @@ -583,9 +598,9 @@ def update_metadata( # Obtain original trial. otq = sqla.select(self._trials_table) otq = otq.where(self._trials_table.c.trial_name == trial_name) - trial_result = self._connection.execute(otq) - row = trial_result.fetchone() + row = self._connection.execute(otq).fetchone() if not row: + self._connection.rollback() raise NotFoundError('No such trial:', trial_name) original_trial = study_pb2.Trial.FromString(row.serialized_trial) @@ -594,5 +609,22 @@ def update_metadata( utq = sqla.update(self._trials_table) utq = utq.where(self._trials_table.c.trial_name == trial_name) utq = utq.values(serialized_trial=original_trial.SerializeToString()) - self._connection.execute(utq) - self._connection.commit() + self._write_or_rollback(utq) + + # Commit ALL changes if everything went well. + self._connection.commit() + + def _write_or_rollback(self, write_query: sqla.sql.Executable) -> None: + """Wraps connection.execute() to roll back on write query failure. + + Args: + write_query: The write query to execute. + + Raises: + sqla.exc.DatabaseError: Generic database error. + """ + try: + self._connection.execute(write_query) + except sqla.exc.DatabaseError as e: + self._connection.rollback() + raise e diff --git a/vizier/_src/service/vizier_server.py b/vizier/_src/service/vizier_server.py index f938d0243..a6c3ce1b8 100644 --- a/vizier/_src/service/vizier_server.py +++ b/vizier/_src/service/vizier_server.py @@ -22,6 +22,7 @@ from concurrent import futures import datetime import time +from typing import Optional import attr import grpc @@ -48,7 +49,9 @@ class DefaultVizierServer: """ _host: str = attr.field(default='localhost') - _database_url: str = attr.field(default=constants.SQL_LOCAL_URL, kw_only=True) + _database_url: Optional[str] = attr.field( + default=constants.SQL_LOCAL_URL, kw_only=True + ) _policy_factory: pythia.PolicyFactory = attr.field( factory=service_policy_factory_lib.DefaultPolicyFactory, kw_only=True )