Skip to content

Commit

Permalink
Merge pull request galaxyproject#16724 from jdavcs/dev_sa20_fix15
Browse files Browse the repository at this point in the history
SQLAlchemy 2.0 upgrades (part 2)
  • Loading branch information
jmchilton authored Oct 17, 2023
2 parents 7c013cc + c320836 commit 9bb21a0
Show file tree
Hide file tree
Showing 14 changed files with 354 additions and 359 deletions.
4 changes: 2 additions & 2 deletions lib/galaxy/datatypes/display_applications/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ def decode_dataset_user(trans, dataset_hash, user_hash):
# decode dataset id as usual
# decode user id using the dataset create time as the key
dataset_id = trans.security.decode_id(dataset_hash)
dataset = trans.sa_session.query(trans.app.model.HistoryDatasetAssociation).get(dataset_id)
dataset = trans.sa_session.get(trans.app.model.HistoryDatasetAssociation, dataset_id)
assert dataset, "Bad Dataset id provided to decode_dataset_user"
if user_hash in [None, "None"]:
user = None
else:
security = IdEncodingHelper(id_secret=dataset.create_time)
user_id = security.decode_id(user_hash)
user = trans.sa_session.query(trans.app.model.User).get(user_id)
user = trans.sa_session.get(trans.app.model.User, user_id)
assert user, "A Bad user id was passed to decode_dataset_user"
return dataset, user
18 changes: 9 additions & 9 deletions lib/galaxy/managers/model_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ def setup_history_export_job(self, request: SetupHistoryExportJob):
include_deleted = request.include_deleted
store_directory = request.store_directory

history = self._sa_session.query(model.History).get(history_id)
history = self._sa_session.get(model.History, history_id)
# symlink files on export, on worker files will tarred up in a dereferenced manner.
with DirectoryModelExportStore(store_directory, app=self._app, export_files="symlink") as export_store:
export_store.export_history(history, include_hidden=include_hidden, include_deleted=include_deleted)
job = self._sa_session.query(model.Job).get(job_id)
job = self._sa_session.get(model.Job, job_id)
job.state = model.Job.states.NEW
with transaction(self._sa_session):
self._sa_session.commit()
Expand Down Expand Up @@ -137,10 +137,10 @@ def prepare_history_content_download(self, request: GenerateHistoryContentDownlo
short_term_storage_target.path
) as export_store:
if request.content_type == HistoryContentType.dataset:
hda = self._sa_session.query(model.HistoryDatasetAssociation).get(request.content_id)
hda = self._sa_session.get(model.HistoryDatasetAssociation, request.content_id)
export_store.add_dataset(hda)
else:
hdca = self._sa_session.query(model.HistoryDatasetCollectionAssociation).get(request.content_id)
hdca = self._sa_session.get(model.HistoryDatasetCollectionAssociation, request.content_id)
export_store.export_collection(
hdca, include_hidden=request.include_hidden, include_deleted=request.include_deleted
)
Expand All @@ -157,7 +157,7 @@ def prepare_invocation_download(self, request: GenerateInvocationDownload):
export_files=export_files,
bco_export_options=self._bco_export_options(request),
)(short_term_storage_target.path) as export_store:
invocation = self._sa_session.query(model.WorkflowInvocation).get(request.invocation_id)
invocation = self._sa_session.get(model.WorkflowInvocation, request.invocation_id)
export_store.export_workflow_invocation(
invocation, include_hidden=request.include_hidden, include_deleted=request.include_deleted
)
Expand All @@ -174,7 +174,7 @@ def write_invocation_to(self, request: WriteInvocationTo):
bco_export_options=self._bco_export_options(request),
user_context=user_context,
)(target_uri) as export_store:
invocation = self._sa_session.query(model.WorkflowInvocation).get(request.invocation_id)
invocation = self._sa_session.get(model.WorkflowInvocation, request.invocation_id)
export_store.export_workflow_invocation(
invocation, include_hidden=request.include_hidden, include_deleted=request.include_deleted
)
Expand All @@ -199,10 +199,10 @@ def write_history_content_to(self, request: WriteHistoryContentTo):
self._app, model_store_format, export_files=export_files, user_context=user_context
)(target_uri) as export_store:
if request.content_type == HistoryContentType.dataset:
hda = self._sa_session.query(model.HistoryDatasetAssociation).get(request.content_id)
hda = self._sa_session.get(model.HistoryDatasetAssociation, request.content_id)
export_store.add_dataset(hda)
else:
hdca = self._sa_session.query(model.HistoryDatasetCollectionAssociation).get(request.content_id)
hdca = self._sa_session.get(model.HistoryDatasetCollectionAssociation, request.content_id)
export_store.export_collection(
hdca, include_hidden=request.include_hidden, include_deleted=request.include_deleted
)
Expand Down Expand Up @@ -267,7 +267,7 @@ def import_model_store(self, request: ImportModelStoreTaskRequest):
)
history_id = request.history_id
if history_id:
history = self._sa_session.query(model.History).get(history_id)
history = self._sa_session.get(model.History, history_id)
else:
history = None
user_context = self._build_user_context(request.user.user_id)
Expand Down
33 changes: 32 additions & 1 deletion lib/galaxy/managers/pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)

from sqlalchemy import (
desc,
false,
or_,
select,
Expand All @@ -42,7 +43,12 @@
ready_galaxy_markdown_for_export,
ready_galaxy_markdown_for_import,
)
from galaxy.model import PageRevision
from galaxy.model import (
Page,
PageRevision,
PageUserShareAssociation,
User,
)
from galaxy.model.base import transaction
from galaxy.model.index_filter_util import (
append_user_filter,
Expand Down Expand Up @@ -631,3 +637,28 @@ def placeholderRenderForSave(trans: ProvidesHistoryContext, item_class, item_id,
def get_page_revision(session: Session, page_id: int):
stmt = select(PageRevision).filter_by(page_id=page_id)
return session.scalars(stmt)


def get_shared_pages(session: Session, user: User):
stmt = (
select(PageUserShareAssociation)
.where(PageUserShareAssociation.user == user)
.join(Page)
.where(Page.deleted == false())
.order_by(desc(Page.update_time))
)
return session.scalars(stmt)


def get_page(session: Session, user: User, slug: str):
stmt = _build_page_query(select(Page), user, slug)
return session.scalar(stmt).first()


def page_exists(session: Session, user: User, slug: str) -> bool:
stmt = _build_page_query(select(Page.id), user, slug)
return session.scalar(stmt).first() is not None


def _build_page_query(select_clause, user: User, slug: str):
return select_clause.where(Page.user == user).where(Page.slug == slug).where(Page.deleted == false()).limit(1)
27 changes: 15 additions & 12 deletions lib/galaxy/managers/quotas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,19 @@
Union,
)

from sqlalchemy import select

from galaxy import (
model,
util,
)
from galaxy.exceptions import ActionInputError
from galaxy.managers import base
from galaxy.model import (
Group,
Quota,
User,
)
from galaxy.model.base import transaction
from galaxy.quota import DatabaseQuotaAgent
from galaxy.quota._schema import (
Expand Down Expand Up @@ -46,7 +53,8 @@ def quota_agent(self) -> DatabaseQuotaAgent:
def create_quota(self, payload: dict, decode_id=None) -> Tuple[model.Quota, str]:
params = CreateQuotaParams.parse_obj(payload)
create_amount = self._parse_amount(params.amount)
if self.sa_session.query(model.Quota).filter(model.Quota.name == params.name).first():
stmt = select(Quota).where(Quota.name == params.name).limit(1)
if self.sa_session.scalars(stmt).first():
raise ActionInputError(
"Quota names must be unique and a quota with that name already exists, please choose another name."
)
Expand Down Expand Up @@ -74,12 +82,10 @@ def create_quota(self, payload: dict, decode_id=None) -> Tuple[model.Quota, str]
else:
# Create the UserQuotaAssociations
in_users = [
self.sa_session.query(model.User).get(decode_id(x) if decode_id else x)
for x in util.listify(params.in_users)
self.sa_session.get(User, decode_id(x) if decode_id else x) for x in util.listify(params.in_users)
]
in_groups = [
self.sa_session.query(model.Group).get(decode_id(x) if decode_id else x)
for x in util.listify(params.in_groups)
self.sa_session.get(Group, decode_id(x) if decode_id else x) for x in util.listify(params.in_groups)
]
if None in in_users:
raise ActionInputError("One or more invalid user id has been provided.")
Expand Down Expand Up @@ -108,12 +114,10 @@ def _parse_amount(self, amount: str) -> Optional[Union[int, bool]]:
return False

def rename_quota(self, quota, params) -> str:
stmt = select(Quota).where(Quota.name == params.name).limit(1)
if not params.name:
raise ActionInputError("Enter a valid name.")
elif (
params.name != quota.name
and self.sa_session.query(model.Quota).filter(model.Quota.name == params.name).first()
):
elif params.name != quota.name and self.sa_session.scalars(stmt).first():
raise ActionInputError("A quota with that name already exists.")
else:
old_name = quota.name
Expand All @@ -131,13 +135,12 @@ def manage_users_and_groups_for_quota(self, quota, params, decode_id=None) -> st
raise ActionInputError("Default quotas cannot be associated with specific users and groups.")
else:
in_users = [
self.sa_session.query(model.User).get(decode_id(x) if decode_id else x)
for x in util.listify(params.in_users)
self.sa_session.get(model.User, decode_id(x) if decode_id else x) for x in util.listify(params.in_users)
]
if None in in_users:
raise ActionInputError("One or more invalid user id has been provided.")
in_groups = [
self.sa_session.query(model.Group).get(decode_id(x) if decode_id else x)
self.sa_session.get(model.Group, decode_id(x) if decode_id else x)
for x in util.listify(params.in_groups)
]
if None in in_groups:
Expand Down
18 changes: 12 additions & 6 deletions lib/galaxy/managers/roles.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import logging
from typing import List

from sqlalchemy import false
from sqlalchemy import (
false,
select,
)
from sqlalchemy.orm import exc as sqlalchemy_exceptions

import galaxy.exceptions
Expand Down Expand Up @@ -44,7 +47,8 @@ def get(self, trans: ProvidesUserContext, role_id: int) -> model.Role:
:raises: InconsistentDatabase, RequestParameterInvalidException, InternalServerError
"""
try:
role = self.session().query(self.model_class).filter(self.model_class.id == role_id).one()
stmt = select(self.model_class).where(self.model_class.id == role_id)
role = self.session().execute(stmt).scalar_one()
except sqlalchemy_exceptions.MultipleResultsFound:
raise galaxy.exceptions.InconsistentDatabase("Multiple roles found with the same id.")
except sqlalchemy_exceptions.NoResultFound:
Expand All @@ -59,7 +63,8 @@ def get(self, trans: ProvidesUserContext, role_id: int) -> model.Role:

def list_displayable_roles(self, trans: ProvidesUserContext) -> List[Role]:
roles = []
for role in trans.sa_session.query(Role).filter(Role.deleted == false()):
stmt = select(Role).where(Role.deleted == false())
for role in trans.sa_session.scalars(stmt):
if trans.user_is_admin or trans.app.security_agent.ok_to_display(trans.user, role):
roles.append(role)
return roles
Expand All @@ -70,15 +75,16 @@ def create_role(self, trans: ProvidesUserContext, role_definition_model: RoleDef
user_ids = role_definition_model.user_ids or []
group_ids = role_definition_model.group_ids or []

if trans.sa_session.query(Role).filter(Role.name == name).first():
stmt = select(Role).where(Role.name == name).limit(1)
if trans.sa_session.scalars(stmt).first():
raise RequestParameterInvalidException(f"A role with that name already exists [{name}]")

role_type = Role.types.ADMIN # TODO: allow non-admins to create roles

role = Role(name=name, description=description, type=role_type)
trans.sa_session.add(role)
users = [trans.sa_session.query(model.User).get(i) for i in user_ids]
groups = [trans.sa_session.query(model.Group).get(i) for i in group_ids]
users = [trans.sa_session.get(model.User, i) for i in user_ids]
groups = [trans.sa_session.get(model.Group, i) for i in group_ids]

# Create the UserRoleAssociations
for user in users:
Expand Down
51 changes: 23 additions & 28 deletions lib/galaxy/managers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from galaxy.model import (
User,
UserAddress,
UserQuotaUsage,
)
from galaxy.model.base import transaction
Expand Down Expand Up @@ -233,13 +234,8 @@ def purge(self, user, flush=True):
user.username = uname_hash
# Redact user addresses as well
if self.app.config.redact_user_address_during_deletion:
user_addresses = (
self.session()
.query(self.app.model.UserAddress)
.filter(self.app.model.UserAddress.user_id == user.id)
.all()
)
for addr in user_addresses:
stmt = select(UserAddress).where(UserAddress.user_id == user.id)
for addr in self.session().scalars(stmt):
addr.desc = new_secure_hash_v2(addr.desc + pseudorandom_value)
addr.name = new_secure_hash_v2(addr.name + pseudorandom_value)
addr.institution = new_secure_hash_v2(addr.institution + pseudorandom_value)
Expand All @@ -264,7 +260,7 @@ def _error_on_duplicate_email(self, email: str) -> None:
raise exceptions.Conflict("Email must be unique", email=email)

def by_id(self, user_id: int) -> model.User:
return self.app.model.session.query(self.model_class).get(user_id)
return self.app.model.session.get(self.model_class, user_id)

# ---- filters
def by_email(self, email: str, filters=None, **kwargs) -> Optional[model.User]:
Expand All @@ -286,7 +282,8 @@ def by_api_key(self, api_key: str, sa_session=None):
return schema.BootstrapAdminUser()
sa_session = sa_session or self.app.model.session
try:
provided_key = sa_session.query(self.app.model.APIKeys).filter_by(key=api_key, deleted=False).one()
stmt = select(self.app.model.APIKeys).filter_by(key=api_key, deleted=False)
provided_key = sa_session.execute(stmt).scalar_one()
except NoResultFound:
raise exceptions.AuthenticationFailed("Provided API key is not valid.")
if provided_key.user.deleted:
Expand Down Expand Up @@ -363,12 +360,7 @@ def get_user_by_identity(self, identity):
user = get_user_by_email(self.session(), identity, self.model_class)
if not user:
# Try a case-insensitive match on the email
user = (
self.session()
.query(self.model_class)
.filter(func.lower(self.model_class.table.c.email) == identity.lower())
.first()
)
user = self._get_user_by_email_case_insensitive(self.session(), identity)
else:
user = get_user_by_username(self.session(), identity, self.model_class)
return user
Expand Down Expand Up @@ -445,7 +437,7 @@ def change_password(self, trans, password=None, confirm=None, token=None, id=Non
if not token and not id:
return None, "Please provide a token or a user and password."
if token:
token_result = trans.sa_session.query(self.app.model.PasswordResetToken).get(token)
token_result = trans.sa_session.get(self.app.model.PasswordResetToken, token)
if not token_result or not token_result.expiration_time > datetime.utcnow():
return None, "Invalid or expired password reset token, please request a new one."
user = token_result.user
Expand Down Expand Up @@ -483,13 +475,14 @@ def __set_password(self, trans, user, password, confirm):
user.set_password_cleartext(password)
# Invalidate all other sessions
if trans.galaxy_session:
for other_galaxy_session in trans.sa_session.query(self.app.model.GalaxySession).filter(
stmt = select(self.app.model.GalaxySession).where(
and_(
self.app.model.GalaxySession.table.c.user_id == user.id,
self.app.model.GalaxySession.table.c.is_valid == true(),
self.app.model.GalaxySession.table.c.id != trans.galaxy_session.id,
self.app.model.GalaxySession.user_id == user.id,
self.app.model.GalaxySession.is_valid == true(),
self.app.model.GalaxySession.id != trans.galaxy_session.id,
)
):
)
for other_galaxy_session in trans.sa_session.scalars(stmt):
other_galaxy_session.is_valid = False
trans.sa_session.add(other_galaxy_session)
trans.sa_session.add(user)
Expand Down Expand Up @@ -581,11 +574,7 @@ def send_reset_email(self, trans, payload, **kwd):
def get_reset_token(self, trans, email):
reset_user = get_user_by_email(trans.sa_session, email, self.app.model.User)
if not reset_user and email != email.lower():
reset_user = (
trans.sa_session.query(self.app.model.User)
.filter(func.lower(self.app.model.User.table.c.email) == email.lower())
.first()
)
reset_user = self._get_user_by_email_case_insensitive(trans.sa_session, email)
if reset_user:
prt = self.app.model.PasswordResetToken(reset_user)
trans.sa_session.add(prt)
Expand Down Expand Up @@ -644,9 +633,11 @@ def get_or_create_remote_user(self, remote_user_email):
for char in [x for x in username if x not in f"{string.ascii_lowercase + string.digits}-."]:
username = username.replace(char, "-")
# Find a unique username - user can change it later
if self.session().query(self.app.model.User).filter_by(username=username).first():
stmt = select(self.app.model.User).filter_by(username=username).limit(1)
if self.session().scalars(stmt).first():
i = 1
while self.session().query(self.app.model.User).filter_by(username=f"{username}-{str(i)}").first():
stmt = select(self.app.model.User).filter_by(username=f"{username}-{str(i)}").limit(1)
while self.session().scalars(stmt).first():
i += 1
username += f"-{str(i)}"
user.username = username
Expand All @@ -660,6 +651,10 @@ def get_or_create_remote_user(self, remote_user_email):
# self.log_event( "Automatically created account '%s'", user.email )
return user

def _get_user_by_email_case_insensitive(self, session, email):
stmt = select(self.app.model.User).where(func.lower(self.app.model.User.email) == email.lower()).limit(1)
return session.scalars(stmt).first()


class UserSerializer(base.ModelSerializer, deletable.PurgableSerializerMixin):
model_manager_class = UserManager
Expand Down
Loading

0 comments on commit 9bb21a0

Please sign in to comment.