
Commit

Updates per PR comments
Signed-off-by: Chris Helma <chelma+github@amazon.com>
chelma committed Nov 15, 2024
1 parent 7676ccd commit 7003634
Showing 4 changed files with 44 additions and 30 deletions.
@@ -133,17 +133,16 @@ def scale(self, units: int, *args, **kwargs) -> CommandResult:
         logger.info(f"Scaling RFS backfill by setting desired count to {units} instances")
         return self.ecs_client.set_desired_count(units)
 
-    def archive(self, *args, archive_dir_path: str = None, **kwargs) -> CommandResult:
+    def archive(self, *args, archive_dir_path: str = None, archive_file_name: str = None, **kwargs) -> CommandResult:
         logger.info("Confirming there are no currently in-progress workers")
         status = self.ecs_client.get_instance_statuses()
         if status.running > 0 or status.pending > 0 or status.desired > 0:
             return CommandResult(False, RfsWorkersInProgress())
 
         try:
-            backup_path = get_working_state_index_backup_path(archive_dir_path)
+            backup_path = get_working_state_index_backup_path(archive_dir_path, archive_file_name)
             logger.info(f"Backing up working state index to {backup_path}")
-            documents = self.target_cluster.fetch_all_documents(WORKING_STATE_INDEX)
-            backup_working_state_index(documents, backup_path)
+            backup_working_state_index(self.target_cluster, WORKING_STATE_INDEX, backup_path)
             logger.info(f"Working state index backed up successful")
 
             logger.info("Cleaning up working state index on target cluster")
@@ -228,26 +227,42 @@ def _get_detailed_status(self) -> Optional[str]:
 
         return "\n".join([f"Shards {key}: {value}" for key, value in values.items() if value is not None])
 
-def get_working_state_index_backup_path(archive_dir_path: str = None) -> str:
+def get_working_state_index_backup_path(archive_dir_path: str = None, archive_file_name: str = None) -> str:
     shared_logs_dir = os.getenv("SHARED_LOGS_DIR_PATH", None)
     if archive_dir_path:
         backup_dir = archive_dir_path
     elif shared_logs_dir is None:
-        backup_dir = "./working_state"
+        backup_dir = "./backfill_working_state"
     else:
-        backup_dir = os.path.join(shared_logs_dir, "working_state")
+        backup_dir = os.path.join(shared_logs_dir, "backfill_working_state")
 
-    file_name = "working_state_backup.json"
+    if archive_file_name:
+        file_name = archive_file_name
+    else:
+        file_name = f"working_state_backup_{datetime.now().strftime('%Y%m%d%H%M%S')}.json"
     return os.path.join(backup_dir, file_name)
 
-def backup_working_state_index(working_state: Dict[str, Any], backup_path: str):
+def backup_working_state_index(cluster: Cluster, index_name:str, backup_path: str):
     # Ensure the backup directory exists
     backup_dir = os.path.dirname(backup_path)
     os.makedirs(backup_dir, exist_ok=True)
 
-    # Write the backup
-    with open(backup_path, "w") as f:
-        json.dump(working_state, f, indent=4)
+    # Backup the docs in the working state index as a JSON array containing batches of documents
+    with open(backup_path, 'w') as outfile:
+        outfile.write("[\n")  # Start the JSON array
+        first_batch = True
+
+        for batch in cluster.fetch_all_documents(index_name=index_name):
+            if not first_batch:
+                outfile.write(",\n")
+            else:
+                first_batch = False
+
+            # Dump the batch of documents as an entry in the array
+            batch_json = json.dumps(batch, indent=4)
+            outfile.write(batch_json)
+
+        outfile.write("\n]")  # Close the JSON array
 
 def parse_query_response(query: dict, cluster: Cluster, label: str) -> Optional[int]:
     try:
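With this change, the working state archive is written as a JSON array of batches, where each batch is a dictionary mapping document _id to _source, instead of a single flat dictionary. A minimal sketch of reading such an archive back into one dictionary; the helper name and the path shown are illustrative, not part of this commit:

import json

def load_working_state_backup(backup_path: str) -> dict:
    # Read the JSON array of batches written by backup_working_state_index
    # and merge them into a single {doc_id: source} mapping.
    with open(backup_path, "r") as infile:
        batches = json.load(infile)
    documents = {}
    for batch in batches:
        documents.update(batch)
    return documents

# e.g. load_working_state_backup("./backfill_working_state/working_state_backup_20241115120000.json")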
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Generator, Optional
 from enum import Enum
 import json
 import logging
@@ -200,9 +200,11 @@ def execute_benchmark_workload(self, workload: str,
         logger.info(f"Executing command: {display_command}")
         subprocess.run(command, shell=True)
 
-    def fetch_all_documents(self, index_name: str, batch_size: int = 100) -> Dict[str, Any]:
-        documents = {}
-        scroll_id = None
+    def fetch_all_documents(self, index_name: str, batch_size: int = 100) -> Generator[Dict[str, Any], None, None]:
+        """
+        Generator that fetches all documents from the specified index in batches
+        """
+
         session = requests.Session()
 
         # Step 1: Initiate the scroll
@@ -221,9 +223,9 @@ def fetch_all_documents(self, index_name: str, batch_size: int = 100) -> Dict[st
         scroll_id = response_json.get('_scroll_id')
         hits = response_json.get('hits', {}).get('hits', [])
 
-        # Add documents to result dictionary
-        for hit in hits:
-            documents[hit['_id']] = hit['_source']
+        # Yield the first batch of documents
+        if hits:
+            yield {hit['_id']: hit['_source'] for hit in hits}
 
         # Step 2: Continue scrolling until no more documents
         while scroll_id and hits:
@@ -241,9 +243,8 @@ def fetch_all_documents(self, index_name: str, batch_size: int = 100) -> Dict[st
             scroll_id = response_json.get('_scroll_id')
             hits = response_json.get('hits', {}).get('hits', [])
 
-            # Add documents to result dictionary
-            for hit in hits:
-                documents[hit['_id']] = hit['_source']
+            if hits:
+                yield {hit['_id']: hit['_source'] for hit in hits}
 
         # Step 3: Cleanup the scroll if necessary
         if scroll_id:
@@ -257,5 +258,3 @@ def fetch_all_documents(self, index_name: str, batch_size: int = 100) -> Dict[st
                 session=session,
                 raise_error=False
             )
-
-        return documents
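Because fetch_all_documents now returns a generator of batches rather than one dictionary, callers iterate over (or flatten) the batches instead of indexing into a single dict. A rough usage sketch, assuming an existing Cluster instance named target_cluster and an illustrative index name:

all_docs = {}
for batch in target_cluster.fetch_all_documents("my-index", batch_size=100):
    # Each batch maps document _id to _source for up to batch_size hits.
    all_docs.update(batch)
print(f"Fetched {len(all_docs)} documents from my-index")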
@@ -315,15 +315,15 @@ def test_ecs_rfs_backfill_archive_as_expected(ecs_rfs_backfill, mocker, tmpdir):
     )
     mocker.patch.object(ECSService, 'get_instance_statuses', autospec=True, return_value=mocked_instance_status)
 
-    mocked_docs = {"id": {"key": "value"}}
+    mocked_docs = [{"id": {"key": "value"}}]
     mocker.patch.object(Cluster, 'fetch_all_documents', autospec=True, return_value=mocked_docs)
 
     mock_api = mocker.patch.object(Cluster, 'call_api', autospec=True, return_value=requests.Response())
 
-    result = ecs_rfs_backfill.archive(archive_dir_path=tmpdir.strpath)
+    result = ecs_rfs_backfill.archive(archive_dir_path=tmpdir.strpath, archive_file_name="backup.json")
 
     assert result.success
-    expected_path = os.path.join(tmpdir.strpath, "working_state_backup.json")
+    expected_path = os.path.join(tmpdir.strpath, "backup.json")
     assert result.value == expected_path
     assert os.path.exists(expected_path)
     with open(expected_path, "r") as f:
@@ -343,7 +343,7 @@ def test_ecs_rfs_backfill_archive_no_index_as_expected(ecs_rfs_backfill, mocker,
     response_404.status_code = 404
     mocker.patch.object(Cluster, 'fetch_all_documents', autospec=True, side_effect=requests.HTTPError(response=response_404, request=requests.Request()))
 
-    result = ecs_rfs_backfill.archive(archive_dir_path=tmpdir.strpath)
+    result = ecs_rfs_backfill.archive()
 
     assert not result.success
     assert isinstance(result.value, WorkingIndexDoesntExist)
@@ -273,8 +273,8 @@ def test_valid_cluster_fetch_all_documents(requests_mock):
         }
     )
     requests_mock.delete(f"{cluster.endpoint}/_search/scroll")
-    documents = cluster.fetch_all_documents(test_index, batch_size=batch_size)
-    assert documents == {"id_1": {"test1": True}, "id_2": {"test2": True}}
+    documents = [batch for batch in cluster.fetch_all_documents(test_index, batch_size=batch_size)]
+    assert documents == [{"id_1": {"test1": True}}, {"id_2": {"test2": True}}]
 
 
 def test_connection_check_with_exception(mocker):

