Skip to content

Commit

Permalink
1. Upgrade to Py 3.12
Browse files Browse the repository at this point in the history
2. Fix SQL datastore's commit logic
3. Fix performance_test.py

PiperOrigin-RevId: 721981221
  • Loading branch information
xingyousong authored and copybara-github committed Feb 1, 2025
1 parent 615bb2a commit b4a2b89
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 62 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.11"] # 3.x disabled b/c of 3.12 test failures w/ GRPC.
python-version: ["3.12"] # 3.x disabled b/c of 3.13 test failures w/ JAX.
suffix: ["core", "benchmarks", "algorithms", "clients", "pyglove", "raytune"]
include:
- suffix: "clients"
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pypi-publish-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'
python-version: '3.12'
- name: Install dependencies
# NOTE: grpcio-tools needs to be periodically updated to support later Python versions.
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pypi-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'
python-version: '3.12'
- name: Install dependencies
# NOTE: grpcio-tools needs to be periodically updated to support later Python versions.
run: |
Expand Down
18 changes: 5 additions & 13 deletions vizier/_src/service/performance_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

import multiprocessing.pool
import time
from absl import logging

from absl import logging
from vizier._src.service import constants
from vizier._src.service import vizier_client
from vizier._src.service import vizier_server
Expand All @@ -41,23 +41,15 @@ def setUpClass(cls):
)
vizier_client.environment_variables.server_endpoint = cls.server.endpoint

@parameterized.parameters(
(1, 10, 2),
(2, 10, 2),
(10, 10, 2),
(50, 5, 2),
(100, 5, 2),
)
@parameterized.parameters((1, 10), (2, 10), (10, 10), (50, 5), (100, 5))
def test_multiple_clients_basic(
self, num_simultaneous_clients, num_trials_per_client, dimension
self, num_simultaneous_clients, num_trials_per_client
):
def fn(client_id: int):
experimenter = experimenters.BBOBExperimenterFactory(
'Sphere', dimension
)()
experimenter = experimenters.BBOBExperimenterFactory('Sphere', 2)()
problem_statement = experimenter.problem_statement()
study_config = pyvizier.StudyConfig.from_problem(problem_statement)
study_config.algorithm = pyvizier.Algorithm.NSGA2
study_config.algorithm = pyvizier.Algorithm.RANDOM_SEARCH

client = vizier_client.create_or_load_study(
owner_id='my_username',
Expand Down
106 changes: 61 additions & 45 deletions vizier/_src/service/sql_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,19 @@ def create_study(self, study: study_pb2.Study) -> resources.StudyResource:

with self._lock:
try:
self._connection.execute(owner_query)
self._connection.commit()
self._write_or_rollback(owner_query)
except sqla.exc.IntegrityError:
logging.info('Owner with name %s currently exists.', owner_name)
self._connection.rollback()

try:
self._connection.execute(study_query)
self._connection.commit()
return study_resource
self._write_or_rollback(study_query)
except sqla.exc.IntegrityError as e:
self._connection.rollback()
raise AlreadyExistsError(
'Study with name %s already exists.' % study.name
) from e
self._connection.commit()

return study_resource

def load_study(self, study_name: str) -> study_pb2.Study:
query = sqla.select(self._studies_table)
Expand Down Expand Up @@ -151,8 +150,9 @@ def update_study(self, study: study_pb2.Study) -> resources.StudyResource:
with self._lock:
if not self._connection.execute(eq).fetchone()[0]:
raise NotFoundError('Study %s does not exist.' % study.name)
self._connection.execute(uq)
self._write_or_rollback(uq)
self._connection.commit()

return study_resource

def delete_study(self, study_name: str) -> None:
Expand All @@ -175,8 +175,8 @@ def delete_study(self, study_name: str) -> None:
with self._lock:
if not self._connection.execute(eq).fetchone()[0]:
raise NotFoundError('Study %s does not exist.' % study_name)
self._connection.execute(dsq)
self._connection.execute(dtq)
self._write_or_rollback(dsq)
self._write_or_rollback(dtq)
self._connection.commit()

def list_studies(self, owner_name: str) -> List[study_pb2.Study]:
Expand Down Expand Up @@ -210,14 +210,14 @@ def create_trial(self, trial: study_pb2.Trial) -> resources.TrialResource:

with self._lock:
try:
self._connection.execute(query)
self._connection.commit()
return trial_resource
self._write_or_rollback(query)
except sqla.exc.IntegrityError as e:
self._connection.rollback()
raise AlreadyExistsError(
'Trial with name %s already exists.' % trial.name
) from e
self._connection.commit()

return trial_resource

def get_trial(self, trial_name: str) -> study_pb2.Trial:
query = sqla.select(self._trials_table)
Expand Down Expand Up @@ -253,7 +253,7 @@ def update_trial(self, trial: study_pb2.Trial) -> resources.TrialResource:
with self._lock:
if not self._connection.execute(eq).fetchone()[0]:
raise NotFoundError('Trial %s does not exist.' % trial.name)
self._connection.execute(uq)
self._write_or_rollback(uq)
self._connection.commit()

return trial_resource
Expand Down Expand Up @@ -291,7 +291,7 @@ def delete_trial(self, trial_name: str) -> None:
with self._lock:
if not self._connection.execute(eq).fetchone()[0]:
raise NotFoundError('Trial %s does not exist.' % trial_name)
self._connection.execute(dq)
self._write_or_rollback(dq)
self._connection.commit()

def max_trial_id(self, study_name: str) -> int:
Expand Down Expand Up @@ -330,16 +330,16 @@ def create_suggestion_operation(
serialized_op=operation.SerializeToString(),
)

try:
with self._lock:
self._connection.execute(query)
self._connection.commit()
return resource
except sqla.exc.IntegrityError as e:
self._connection.rollback()
raise AlreadyExistsError(
'Suggest Op with name %s already exists.' % operation.name
) from e
with self._lock:
try:
self._write_or_rollback(query)
except sqla.exc.IntegrityError as e:
raise AlreadyExistsError(
'Suggest Op with name %s already exists.' % operation.name
) from e
self._connection.commit()

return resource

def get_suggestion_operation(
self, operation_name: str
Expand Down Expand Up @@ -386,8 +386,9 @@ def update_suggestion_operation(
with self._lock:
if not self._connection.execute(eq).fetchone()[0]:
raise NotFoundError('Suggest op %s does not exist.' % operation.name)
self._connection.execute(uq)
self._write_or_rollback(uq)
self._connection.commit()

return resource

def list_suggestion_operations(
Expand Down Expand Up @@ -474,16 +475,16 @@ def create_early_stopping_operation(
serialized_op=operation.SerializeToString(),
)

try:
with self._lock:
self._connection.execute(query)
self._connection.commit()
return resource
except sqla.exc.IntegrityError as e:
self._connection.rollback()
raise AlreadyExistsError(
'Early stopping op with name %s already exists.' % operation.name
) from e
with self._lock:
try:
self._write_or_rollback(query)
except sqla.exc.IntegrityError as e:
raise AlreadyExistsError(
'Early stopping op with name %s already exists.' % operation.name
) from e
self._connection.commit()

return resource

def get_early_stopping_operation(
self, operation_name: str
Expand Down Expand Up @@ -553,8 +554,7 @@ def update_metadata(
sq = sq.where(self._studies_table.c.study_name == study_name)

with self._lock:
study_result = self._connection.execute(sq)
row = study_result.fetchone()
row = self._connection.execute(sq).fetchone()
if not row:
raise NotFoundError('No such study:', s_resource.name)
original_study = study_pb2.Study.FromString(row.serialized_study)
Expand All @@ -567,8 +567,7 @@ def update_metadata(
usq = sqla.update(self._studies_table)
usq = usq.where(self._studies_table.c.study_name == study_name)
usq = usq.values(serialized_study=original_study.SerializeToString())
self._connection.execute(usq)
self._connection.commit()
self._write_or_rollback(usq)

# Split the trial-related metadata by Trial.
split_metadata = collections.defaultdict(list)
Expand All @@ -583,9 +582,9 @@ def update_metadata(
# Obtain original trial.
otq = sqla.select(self._trials_table)
otq = otq.where(self._trials_table.c.trial_name == trial_name)
trial_result = self._connection.execute(otq)
row = trial_result.fetchone()
row = self._connection.execute(otq).fetchone()
if not row:
self._connection.rollback()
raise NotFoundError('No such trial:', trial_name)
original_trial = study_pb2.Trial.FromString(row.serialized_trial)

Expand All @@ -594,5 +593,22 @@ def update_metadata(
utq = sqla.update(self._trials_table)
utq = utq.where(self._trials_table.c.trial_name == trial_name)
utq = utq.values(serialized_trial=original_trial.SerializeToString())
self._connection.execute(utq)
self._connection.commit()
self._write_or_rollback(utq)

# Commit ALL changes if everything went well.
self._connection.commit()

def _write_or_rollback(self, write_query: sqla.sql.Executable) -> None:
"""Wraps connection.execute() to roll back on write query failure.
Args:
write_query: The write query to execute.
Raises:
sqla.exc.DatabaseError: Generic database error.
"""
try:
self._connection.execute(write_query)
except sqla.exc.DatabaseError as e:
self._connection.rollback()
raise e
5 changes: 4 additions & 1 deletion vizier/_src/service/vizier_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from concurrent import futures
import datetime
import time
from typing import Optional

import attr
import grpc
Expand All @@ -48,7 +49,9 @@ class DefaultVizierServer:
"""

_host: str = attr.field(default='localhost')
_database_url: str = attr.field(default=constants.SQL_LOCAL_URL, kw_only=True)
_database_url: Optional[str] = attr.field(
default=constants.SQL_LOCAL_URL, kw_only=True
)
_policy_factory: pythia.PolicyFactory = attr.field(
factory=service_policy_factory_lib.DefaultPolicyFactory, kw_only=True
)
Expand Down

0 comments on commit b4a2b89

Please sign in to comment.