diff --git a/dk-installer.py b/dk-installer.py index c0ee246..b6e385b 100755 --- a/dk-installer.py +++ b/dk-installer.py @@ -89,6 +89,11 @@ # +def _get_tg_base_url(args): + protocol = "https" if args.ssl_cert_file and args.ssl_key_file else "http" + return f"{protocol}://localhost:{args.port}" + + def collect_images_digest(action, images, env=None): if images: action.run_cmd( @@ -1661,6 +1666,7 @@ def __init__(self): self.update_version = False self.update_analytics = False self.update_token = False + self.update_base_url = False super().__init__() def pre_execute(self, action, args): @@ -1720,12 +1726,19 @@ def pre_execute(self, action, args): self.update_token = "TG_JWT_HASHING_KEY" not in contents - if not any((self.update_version, self.update_analytics, self.update_token)): + self.update_base_url = "TG_UI_BASE_URL" not in contents + if self.update_base_url: + port_match = re.search(r"- (\d+):8501", contents) + port = port_match.group(1) if port_match else str(TESTGEN_DEFAULT_PORT) + protocol = "https" if "SSL_CERT_FILE" in contents else "http" + self._base_url = f"{protocol}://localhost:{port}" + + if not any((self.update_version, self.update_analytics, self.update_token, self.update_base_url)): CONSOLE.msg("No changes will be applied.") raise AbortAction def execute(self, action, args): - if not any((self.update_version, self.update_analytics, self.update_token)): + if not any((self.update_version, self.update_analytics, self.update_token, self.update_base_url)): raise SkipStep contents = action.get_compose_file_path(args).read_text() @@ -1755,6 +1768,11 @@ def execute(self, action, args): var = f"\n{match.group(1)}TG_JWT_HASHING_KEY: {str(base64.b64encode(random.randbytes(32)), 'ascii')}" contents = contents[0 : match.end()] + match.group(1) + var + contents[match.end() :] + if self.update_base_url: + match = re.search(r"^([ \t]+)TG_METADATA_DB_HOST:.*$", contents, flags=re.M) + var = f"\n{match.group(1)}TG_UI_BASE_URL: {self._base_url}" + contents = contents[0 : match.end()] + var + contents[match.end() :] + action.get_compose_file_path(args).write_text(contents) @@ -1787,10 +1805,9 @@ def pre_execute(self, action, args): def on_action_success(self, action, args): super().on_action_success(action, args) - protocol = "https" if args.ssl_cert_file and args.ssl_key_file else "http" cred_file_path = action.data_folder.joinpath(CREDENTIALS_FILE.format(args.prod)) with CONSOLE.tee(cred_file_path) as console_tee: - console_tee(f"User Interface: {protocol}://localhost:{args.port}") + console_tee(f"User Interface: {_get_tg_base_url(args)}") console_tee("CLI Access: docker compose exec engine bash") console_tee("") console_tee(f"Username: {self.username}") @@ -1849,6 +1866,7 @@ def get_compose_file_contents(self, action, args): TG_EXPORT_TO_OBSERVABILITY_VERIFY_SSL: no TG_INSTANCE_ID: {action.analytics.get_instance_id()} TG_ANALYTICS: {"yes" if args.send_analytics_data else "no"} + TG_UI_BASE_URL: {_get_tg_base_url(args)} {ssl_variables} services: diff --git a/tests/test_tg_install.py b/tests/test_tg_install.py index fbe91ed..8922ec4 100644 --- a/tests/test_tg_install.py +++ b/tests/test_tg_install.py @@ -86,3 +86,27 @@ def test_tg_create_compose_file_abort_args(arg_to_set, tg_install_action, stdout console_msg_mock.assert_any_msg_contains( "Both --ssl-cert-file and --ssl-key-file must be provided to use SSL certificates.", ) + + +@pytest.mark.integration +def test_tg_compose_contains_base_url(tg_install_action, start_cmd_mock, stdout_mock, compose_path): + tg_install_action.execute() + contents = compose_path.read_text() + assert "TG_UI_BASE_URL: http://localhost:8501" in contents + + +@pytest.mark.integration +def test_tg_compose_base_url_custom_port(tg_install_action, start_cmd_mock, stdout_mock, args_mock, compose_path): + args_mock.port = 9000 + tg_install_action.execute() + contents = compose_path.read_text() + assert "TG_UI_BASE_URL: http://localhost:9000" in contents + + +@pytest.mark.integration +def test_tg_compose_base_url_ssl(tg_install_action, start_cmd_mock, stdout_mock, args_mock, compose_path): + args_mock.ssl_cert_file = "/path/to/cert.crt" + args_mock.ssl_key_file = "/path/to/cert.key" + tg_install_action.execute() + contents = compose_path.read_text() + assert "TG_UI_BASE_URL: https://localhost:8501" in contents diff --git a/tests/test_tg_upgrade.py b/tests/test_tg_upgrade.py index 60828d6..be94d92 100644 --- a/tests/test_tg_upgrade.py +++ b/tests/test_tg_upgrade.py @@ -125,7 +125,9 @@ def test_tg_upgrade_abort( ): args_mock.skip_verify = False set_version_check_mock(version_check_mock, "1.0.0") - initial_compose_content = get_compose_content("TG_INSTANCE_ID: test-instance-id") + initial_compose_content = get_compose_content( + "TG_INSTANCE_ID: test-instance-id", "TG_UI_BASE_URL: http://localhost:8501" + ) compose_path.write_text(initial_compose_content) with pytest.raises(AbortAction): @@ -186,3 +188,43 @@ def test_tg_upgrade_disable_analytics( assert "TG_ANALYTICS: no" in compose_content assert "image: datakitchen/dataops-testgen:v2.14.5" in compose_content console_msg_mock.assert_any_msg_contains("Application is already up-to-date.") + + +@pytest.mark.integration +def test_tg_upgrade_adds_base_url( + tg_upgrade_action, + compose_path, + start_cmd_mock, + tg_upgrade_stdout_side_effect, + args_mock, + version_check_mock, +): + set_version_check_mock(version_check_mock, "1.0.0") + compose_path.write_text(get_compose_content("TG_INSTANCE_ID: test-instance-id")) + + tg_upgrade_action.execute(args_mock) + + compose_content = compose_path.read_text() + assert "TG_UI_BASE_URL: http://localhost:8501" in compose_content + + +@pytest.mark.integration +def test_tg_upgrade_preserves_existing_base_url( + tg_upgrade_action, + compose_path, + start_cmd_mock, + tg_upgrade_stdout_side_effect, + args_mock, + version_check_mock, +): + args_mock.skip_verify = True + set_version_check_mock(version_check_mock, "1.1.0") + compose_path.write_text( + get_compose_content("TG_INSTANCE_ID: test-instance-id", "TG_UI_BASE_URL: https://custom.example.com") + ) + + tg_upgrade_action.execute(args_mock) + + compose_content = compose_path.read_text() + assert "TG_UI_BASE_URL: https://custom.example.com" in compose_content + assert compose_content.count("TG_UI_BASE_URL") == 1