49 changes: 12 additions & 37 deletions cronjobs/src/commands/backport_records.py
@@ -17,24 +17,15 @@ def parse_querystring(qs):
}


def backport_records(event, context, **kwargs):
def backport_records():
"""Backport records creations, updates and deletions from one collection to another."""
server_url = event["server"]
source_auth = (
event.get("backport_records_source_auth")
or os.environ["BACKPORT_RECORDS_SOURCE_AUTH"]
)
dest_auth = event.get(
"backport_records_dest_auth",
os.getenv("BACKPORT_RECORDS_DEST_AUTH", source_auth),
)
SERVER_URL = os.environ["SERVER"]
source_auth = os.environ["BACKPORT_RECORDS_SOURCE_AUTH"]
dest_auth = os.getenv("BACKPORT_RECORDS_DEST_AUTH", source_auth)

mappings = []

if mappings_env := (
event.get("backport_records_mappings")
or os.getenv("BACKPORT_RECORDS_MAPPINGS", "")
):
if mappings_env := os.getenv("BACKPORT_RECORDS_MAPPINGS", ""):
regexp = re.compile(
r"^(?P<sbid>[^/]+)/(?P<scid>[^/\?]+)(?P<qs>\?.*)? -> (?P<dbid>[^/]+)/(?P<dcid>[^/]+)$"
)
@@ -54,39 +45,23 @@ def backport_records(event, context, **kwargs):
else:
raise ValueError(f"Invalid syntax in line {entry}")
else:
sbid = (
event.get("backport_records_source_bucket")
or os.environ["BACKPORT_RECORDS_SOURCE_BUCKET"]
)
scid = (
event.get("backport_records_source_collection")
or os.environ["BACKPORT_RECORDS_SOURCE_COLLECTION"]
)
filters_json = event.get("backport_records_source_filters") or os.getenv(
"BACKPORT_RECORDS_SOURCE_FILTERS", ""
)
sbid = os.environ["BACKPORT_RECORDS_SOURCE_BUCKET"]
scid = os.environ["BACKPORT_RECORDS_SOURCE_COLLECTION"]
filters_json = os.getenv("BACKPORT_RECORDS_SOURCE_FILTERS", "")
filters_dict = json.loads(filters_json or "{}")

dbid = event.get(
"backport_records_dest_bucket",
os.getenv("BACKPORT_RECORDS_DEST_BUCKET", sbid),
)
dcid = event.get(
"backport_records_dest_collection",
os.getenv("BACKPORT_RECORDS_DEST_COLLECTION", scid),
)
dbid = os.getenv("BACKPORT_RECORDS_DEST_BUCKET", sbid)
dcid = os.getenv("BACKPORT_RECORDS_DEST_COLLECTION", scid)

if sbid == dbid and scid == dcid:
raise ValueError("Cannot copy records: destination is identical to source")

mappings.append((sbid, scid, filters_dict, dbid, dcid))

safe_headers = event.get(
"safe_headers", config("SAFE_HEADERS", default=False, cast=bool)
)
safe_headers = config("SAFE_HEADERS", default=False, cast=bool)

for mapping in mappings:
execute_backport(server_url, source_auth, dest_auth, safe_headers, *mapping)
execute_backport(SERVER_URL, source_auth, dest_auth, safe_headers, *mapping)


def execute_backport(
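With the Lambda-style `event`/`context` arguments gone, the command is configured entirely through environment variables. A minimal sketch of an invocation (the server URL, credentials, and mapping are made-up values, and the import path assumes the cronjobs `src/` directory is on `sys.path`):

    import os

    # Hypothetical values, for illustration only.
    os.environ["SERVER"] = "https://remote-settings.example.com/v1"
    os.environ["BACKPORT_RECORDS_SOURCE_AUTH"] = "admin:s3cr3t"
    # One "source -> destination" pair per line; an optional querystring on
    # the source side becomes record filters (see the regexp above).
    os.environ["BACKPORT_RECORDS_MAPPINGS"] = (
        "main-workspace/cfr?type=nimbus -> main-workspace/cfr-experiments"
    )

    from commands.backport_records import backport_records

    backport_records()  # no event/context arguments anymore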
8 changes: 3 additions & 5 deletions cronjobs/src/commands/build_bundles.py
@@ -19,7 +19,6 @@
from . import KintoClient, call_parallel, fetch_all_changesets, retry_timeout


SERVER = os.getenv("SERVER")
BUNDLE_MAX_SIZE_BYTES = int(os.getenv("BUNDLE_MAX_SIZE_BYTES", "20_000_000"))
ENVIRONMENT = os.getenv("ENVIRONMENT", "local")
REALM = os.getenv("REALM", "test")
@@ -114,7 +113,7 @@ def sync_cloud_storage(
print(f"Deleted gs://{storage_bucket}/{blob.name}")


def build_bundles(event, context):
def build_bundles():
"""
Build and upload bundles of changesets and attachments.

@@ -126,9 +125,8 @@ def build_bundles(event, context):
- builds `{bid}--{cid}.zip` for each of them
- sends the bundles to the Cloud storage bucket
"""
rs_server = event.get("server") or SERVER

client = KintoClient(server_url=rs_server)
server = os.getenv("SERVER")
client = KintoClient(server_url=server)

base_url = client.server_info()["capabilities"]["attachments"]["base_url"]

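Moving the `os.getenv("SERVER")` lookup from module scope into `build_bundles()` is what lets tests patch the environment after import (see the `monkeypatch.setenv` change in the tests below). A rough sketch of the difference, assuming nothing else caches the variable:

    import os

    SERVER = os.getenv("SERVER")  # old: value frozen at import time

    def connect_old():
        return SERVER  # ignores a later monkeypatch.setenv("SERVER", ...)

    def connect_new():
        return os.getenv("SERVER")  # new: read at call time, sees the patched value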
2 changes: 1 addition & 1 deletion cronjobs/src/commands/expire_orphan_attachments.py
@@ -19,7 +19,7 @@
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "100"))


def expire_orphan_attachments(event, context):
def expire_orphan_attachments():
"""
This cronjob will set the custom time field on orphaned attachments to the current time.
We then have a retention policy on the GCS bucket that will
6 changes: 4 additions & 2 deletions cronjobs/src/commands/git_export.py
@@ -97,7 +97,7 @@
GIT_EMAIL = _email.rstrip(">")


def git_export(event, context):
def git_export():
"""
Export Remote Settings data to a Git repository.
"""
@@ -499,7 +499,9 @@ def process_attachments(
if existing := existing_attachments.get(location):
existing_hash, existing_size = existing
if existing_hash != hash or existing_size != size:
print(f"Bundle {path} {'is new' if existing_hash is None else 'has changed'}")
print(
f"Bundle {path} {'is new' if existing_hash is None else 'has changed'}"
)
changed_attachments.append((hash, size, url))
return changed_attachments, common_content

2 changes: 1 addition & 1 deletion cronjobs/src/commands/purge_history.py
@@ -15,7 +15,7 @@ def utcnow():
return datetime.now(timezone.utc)


def purge_history(*args, **kwargs):
def purge_history():
"""Purge old history entries on a regular basis."""
server_url = config("SERVER", default="http://localhost:8888/v1")
auth = config("AUTH", default="admin:s3cr3t")
17 changes: 7 additions & 10 deletions cronjobs/src/commands/refresh_signature.py
@@ -39,18 +39,15 @@ def get_signed_source(server_info, change):
}


def refresh_signature(event, context, **kwargs):
def refresh_signature():
"""Refresh the signatures of each collection."""
server_url = event["server"]
auth = event.get("refresh_signature_auth") or os.getenv("REFRESH_SIGNATURE_AUTH")
max_signature_age = int(
event.get("max_signature_age", os.getenv("MAX_SIGNATURE_AGE", 7))
)

server = os.environ["SERVER"]
auth = os.getenv("REFRESH_SIGNATURE_AUTH")
max_signature_age = int(os.getenv("MAX_SIGNATURE_AGE", 7))

# Look at the collections in the changes endpoint.
bucket = event.get("bucket", "monitor")
collection = event.get("collection", "changes")
client = Client(server_url=server_url, bucket=bucket, collection=collection)
client = Client(server_url=server, bucket="monitor", collection="changes")
print("Looking at %s: " % client.get_endpoint("collection"))
changes = client.get_records()

@@ -67,7 +64,7 @@ def refresh_signature(event, context, **kwargs):
continue

client = Client(
server_url=server_url,
server_url=server,
bucket=source["bucket"],
collection=source["collection"],
auth=auth,
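The refactored command always starts from the `monitor/changes` endpoint instead of an event-supplied bucket and collection. A sketch of that first step, assuming `Client` is kinto_http's client (as the `server_url`/`bucket`/`collection` keywords suggest) and inferring the record fields from the `source["bucket"]`/`source["collection"]` lookups above:

    from kinto_http import Client  # assumption: the `Client` used in this module

    client = Client(
        server_url="https://remote-settings.example.com/v1",  # hypothetical
        bucket="monitor",
        collection="changes",
    )
    for change in client.get_records():
        # Each entry points at a signed collection to inspect, roughly:
        # {"bucket": "main", "collection": "cfr", "last_modified": 1700000000000}
        print(change["bucket"], change["collection"])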
11 changes: 3 additions & 8 deletions cronjobs/src/main.py
@@ -65,12 +65,7 @@ def white_bold(s):
)


def run(command, event=None, context=None):
if event is None:
event = {"server": SERVER_URL}
if context is None:
context = {"sentry_sdk": sentry_sdk}

def run(command):
if isinstance(command, (str,)):
# Import the command module and returns its main function.
mod = importlib.import_module(f"commands.{command}")
@@ -82,10 +77,10 @@ def run(command, event=None, context=None):
# See https://docs.sentry.io/platforms/python/guides/gcp-functions/

# Option to test failure to test Sentry integration.
if event.get("force_fail") or os.getenv("FORCE_FAIL"):
if os.getenv("FORCE_FAIL"):
raise Exception("Found forced failure flag")

return command(event, context)
return command()


def main(*args):
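With the event plumbing removed, `run()` invokes every command with no arguments. A sketch of the dispatch logic (the `getattr` step is hidden by the collapsed lines above, so its exact shape here is an assumption):

    import importlib

    def run(command):
        if isinstance(command, str):
            # Import the command module and resolve its entry point, e.g.
            # "purge_history" -> commands.purge_history.purge_history (assumed naming).
            mod = importlib.import_module(f"commands.{command}")
            command = getattr(mod, command)
        return command()  # previously: command(event, context)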
106 changes: 43 additions & 63 deletions cronjobs/tests/commands/test_backport_records.py
@@ -1,5 +1,6 @@
import json
import unittest
from unittest import mock

import pytest
import responses
@@ -8,7 +9,6 @@

class TestRecordsBackport(unittest.TestCase):
server = "https://fake-server.net/v1"
auth = ("foo", "bar")
source_bid = "main"
source_cid = "one"
dest_bid = "main-workspace"
@@ -25,6 +25,21 @@ def setUp(self):
)
self.dest_records_uri = f"{self.dest_collection_uri}/records"

# Set environment variables.
self.patcher = mock.patch.dict(
"os.environ",
{
"SERVER": self.server,
"BACKPORT_RECORDS_SOURCE_AUTH": "foo:bar",
"BACKPORT_RECORDS_SOURCE_BUCKET": self.source_bid,
"BACKPORT_RECORDS_SOURCE_COLLECTION": self.source_cid,
"BACKPORT_RECORDS_DEST_BUCKET": self.dest_bid,
"BACKPORT_RECORDS_DEST_COLLECTION": self.dest_cid,
},
)
self.patcher.start()
self.addCleanup(self.patcher.stop)

@responses.activate
def test_missing_records_are_backported(self):
responses.add(
@@ -52,18 +67,13 @@ def test_missing_records_are_backported(self):
)
responses.add(responses.POST, self.server + "/batch", json={"responses": []})

backport_records(
event={
"server": self.server,
"backport_records_source_auth": self.auth,
"backport_records_source_bucket": self.source_bid,
"backport_records_source_collection": self.source_cid,
"backport_records_source_filters": '{"min_age": 20}',
"backport_records_dest_bucket": self.dest_bid,
"backport_records_dest_collection": self.dest_cid,
with mock.patch.dict(
"os.environ",
{
"BACKPORT_RECORDS_SOURCE_FILTERS": '{"min_age": 20}',
},
context=None,
)
):
backport_records()

assert responses.calls[0].request.method == "GET"
assert responses.calls[0].request.url.endswith("?min_age=20")
@@ -106,18 +116,8 @@ def test_outdated_records_are_overwritten(self):
)
responses.add(responses.POST, self.server + "/batch", json={"responses": []})

backport_records(
event={
"server": self.server,
"safe_headers": True,
"backport_records_source_auth": self.auth,
"backport_records_source_bucket": self.source_bid,
"backport_records_source_collection": self.source_cid,
"backport_records_dest_bucket": self.dest_bid,
"backport_records_dest_collection": self.dest_cid,
},
context=None,
)
with mock.patch.dict("os.environ", {"SAFE_HEADERS": "true"}):
backport_records()

assert responses.calls[3].request.method == "POST"
posted_records = json.loads(responses.calls[3].request.body)
@@ -162,17 +162,7 @@ def test_nothing_to_do(self):
},
)

backport_records(
event={
"server": self.server,
"backport_records_source_auth": self.auth,
"backport_records_source_bucket": self.source_bid,
"backport_records_source_collection": self.source_cid,
"backport_records_dest_bucket": self.dest_bid,
"backport_records_dest_collection": self.dest_cid,
},
context=None,
)
backport_records()

assert len(responses.calls) == 3
assert responses.calls[0].request.method == "GET"
@@ -241,17 +231,7 @@ def test_pending_changes(self):
},
)

backport_records(
event={
"server": self.server,
"backport_records_source_auth": self.auth,
"backport_records_source_bucket": self.source_bid,
"backport_records_source_collection": self.source_cid,
"backport_records_dest_bucket": self.dest_bid,
"backport_records_dest_collection": self.dest_cid,
},
context=None,
)
backport_records()

assert len(responses.calls) == 6
assert responses.calls[0].request.method == "GET"
@@ -301,14 +281,14 @@ def test_pending_changes(self):
)
def test_correct_multiline_mappings(mapping_env, expected_calls):
with unittest.mock.patch("commands.backport_records.execute_backport") as mocked:
backport_records(
event={
"server": "http://server",
"backport_records_source_auth": "admin:admin",
"backport_records_mappings": mapping_env,
with mock.patch.dict(
"os.environ",
{
"BACKPORT_RECORDS_SOURCE_AUTH": "admin:admin",
"BACKPORT_RECORDS_MAPPINGS": mapping_env,
},
context=None,
)
):
backport_records()
for expected_params in expected_calls:
mocked.assert_any_call(
unittest.mock.ANY,
Expand All @@ -327,13 +307,13 @@ def test_correct_multiline_mappings(mapping_env, expected_calls):
],
)
def test_incorrect_multiline_mappings(mapping_env):
with unittest.mock.patch("commands.backport_records.execute_backport"):
with pytest.raises(expected_exception=ValueError, match="Invalid syntax"):
backport_records(
event={
"server": "http://server",
"backport_records_source_auth": "admin:admin",
"backport_records_mappings": mapping_env,
},
context=None,
)
with mock.patch.dict(
"os.environ",
{
"BACKPORT_RECORDS_SOURCE_AUTH": "admin:admin",
"BACKPORT_RECORDS_MAPPINGS": mapping_env,
},
):
with unittest.mock.patch("commands.backport_records.execute_backport"):
with pytest.raises(expected_exception=ValueError, match="Invalid syntax"):
backport_records()
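
One detail these tests rely on: `mock.patch.dict` merges into the existing mapping by default and restores it on exit, so each test can layer one extra variable on top of the `setUp()` baseline. A minimal illustration, assuming `A` and `B` are not already set in the environment:

    import os
    from unittest import mock

    with mock.patch.dict("os.environ", {"A": "1"}):
        with mock.patch.dict("os.environ", {"B": "2"}):
            assert os.environ["A"] == "1" and os.environ["B"] == "2"
        assert "B" not in os.environ  # inner patch unwound; outer still active
    assert "A" not in os.environ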
5 changes: 3 additions & 2 deletions cronjobs/tests/commands/test_build_bundles.py
@@ -164,9 +164,10 @@ def test_build_bundles(
mock_write_zip,
mock_write_json_mozlz4,
mock_sync_cloud_storage,
monkeypatch,
):
server_url = "http://testserver"
event = {"server": server_url}
monkeypatch.setenv("SERVER", server_url)

responses.add(
responses.GET,
@@ -282,7 +283,7 @@ def test_build_bundles(
status=404,
)

build_bundles(event, context={})
build_bundles()

assert mock_write_zip.call_count == 1 # only one for the attachments
calls = mock_write_zip.call_args_list