diff --git a/dashboard/agent.py b/dashboard/agent.py index 2542ecff5307..ed28ec5d0cd4 100644 --- a/dashboard/agent.py +++ b/dashboard/agent.py @@ -1,7 +1,6 @@ import argparse import asyncio import io -import json import logging import logging.handlers import os @@ -13,11 +12,10 @@ import ray._private.utils import ray.dashboard.consts as dashboard_consts import ray.dashboard.utils as dashboard_utils -import ray.experimental.internal_kv as internal_kv from ray._private.gcs_pubsub import GcsAioPublisher, GcsPublisher from ray._private.gcs_utils import GcsAioClient, GcsClient from ray._private.ray_logging import setup_component_logger -from ray.core.generated import agent_manager_pb2, agent_manager_pb2_grpc +from ray.core.generated import agent_manager_pb2, agent_manager_pb2_grpc, common_pb2 from ray.experimental.internal_kv import ( _initialize_internal_kv, _internal_kv_initialized, @@ -262,22 +260,20 @@ async def _check_parent(): # TODO: Use async version if performance is an issue # -1 should indicate that http server is not started. http_port = -1 if not self.http_server else self.http_server.http_port - internal_kv._internal_kv_put( - f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{self.node_id}", - json.dumps([http_port, self.grpc_port]), - namespace=ray_constants.KV_NAMESPACE_DASHBOARD, - ) # Register agent to agent manager. raylet_stub = agent_manager_pb2_grpc.AgentManagerServiceStub( self.aiogrpc_raylet_channel ) - await raylet_stub.RegisterAgent( agent_manager_pb2.RegisterAgentRequest( - agent_id=self.agent_id, - agent_port=self.grpc_port, - agent_ip_address=self.ip, + agent_info=common_pb2.AgentInfo( + id=self.agent_id, + pid=os.getpid(), + grpc_port=self.grpc_port, + http_port=http_port, + ip_address=self.ip, + ) ) ) diff --git a/dashboard/consts.py b/dashboard/consts.py index f2d27abd70c7..d30c5e1dd940 100644 --- a/dashboard/consts.py +++ b/dashboard/consts.py @@ -1,7 +1,6 @@ from ray._private.ray_constants import env_integer DASHBOARD_LOG_FILENAME = "dashboard.log" -DASHBOARD_AGENT_PORT_PREFIX = "DASHBOARD_AGENT_PORT_PREFIX:" DASHBOARD_AGENT_LOG_FILENAME = "dashboard_agent.log" DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS = 2 RAY_STATE_SERVER_MAX_HTTP_REQUEST_ENV_NAME = "RAY_STATE_SERVER_MAX_HTTP_REQUEST" diff --git a/dashboard/modules/node/node_head.py b/dashboard/modules/node/node_head.py index b0562992c8eb..24b538e4e794 100644 --- a/dashboard/modules/node/node_head.py +++ b/dashboard/modules/node/node_head.py @@ -1,5 +1,4 @@ import asyncio -import json import logging import re import time @@ -7,7 +6,6 @@ import aiohttp.web import ray._private.utils -import ray.dashboard.consts as dashboard_consts import ray.dashboard.optional_utils as dashboard_optional_utils import ray.dashboard.utils as dashboard_utils from ray._private import ray_constants @@ -131,10 +129,9 @@ async def _update_nodes(self): try: nodes = await self._get_nodes() - alive_node_ids = [] - alive_node_infos = [] node_id_to_ip = {} node_id_to_hostname = {} + agents = dict(DataSource.agents) for node in nodes.values(): node_id = node["nodeId"] ip = node["nodeManagerAddress"] @@ -150,20 +147,10 @@ async def _update_nodes(self): node_id_to_hostname[node_id] = hostname assert node["state"] in ["ALIVE", "DEAD"] if node["state"] == "ALIVE": - alive_node_ids.append(node_id) - alive_node_infos.append(node) - - agents = dict(DataSource.agents) - for node_id in alive_node_ids: - key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}" f"{node_id}" - # TODO: Use async version if performance is an issue - agent_port = ray.experimental.internal_kv._internal_kv_get( - key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD - ) - if agent_port: - agents[node_id] = json.loads(agent_port) - for node_id in agents.keys() - set(alive_node_ids): - agents.pop(node_id, None) + agents[node_id] = [ + node["agentInfo"]["httpPort"], + node["agentInfo"]["grpcPort"], + ] DataSource.node_id_to_ip.reset(node_id_to_ip) DataSource.node_id_to_hostname.reset(node_id_to_hostname) diff --git a/dashboard/modules/node/tests/test_node.py b/dashboard/modules/node/tests/test_node.py index ff300025df77..da3d0cb12a57 100644 --- a/dashboard/modules/node/tests/test_node.py +++ b/dashboard/modules/node/tests/test_node.py @@ -6,6 +6,7 @@ import traceback import random import pytest +import psutil import ray import threading from datetime import datetime, timedelta @@ -18,6 +19,7 @@ wait_for_condition, wait_until_succeeded_without_exception, ) +from ray._private.state import state logger = logging.getLogger(__name__) @@ -348,5 +350,33 @@ def verify(): wait_for_condition(verify, timeout=15) +# See detail: https://github.com/ray-project/ray/issues/24361 +@pytest.mark.skipif(sys.platform == "win32", reason="Flaky on Windows.") +def test_node_register_with_agent(ray_start_cluster_head): + def test_agent_port(pid, port): + p = psutil.Process(pid) + assert p.cmdline()[2].endswith("dashboard/agent.py") + + for c in p.connections(): + if c.status == psutil.CONN_LISTEN and c.laddr.port == port: + return + assert False + + def test_agent_process(pid): + p = psutil.Process(pid) + assert p.cmdline()[2].endswith("dashboard/agent.py") + + for node_info in state.node_table(): + agent_info = node_info["AgentInfo"] + assert agent_info["IpAddress"] == node_info["NodeManagerAddress"] + test_agent_port(agent_info["Pid"], agent_info["GrpcPort"]) + if agent_info["HttpPort"] >= 0: + test_agent_port(agent_info["Pid"], agent_info["HttpPort"]) + else: + # Port conflicts may be caused that the previous + # test did not kill the agent cleanly + assert agent_info["HttpPort"] == -1 + + if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) diff --git a/dashboard/tests/test_dashboard.py b/dashboard/tests/test_dashboard.py index d8ca31714fcf..7c0123f9564c 100644 --- a/dashboard/tests/test_dashboard.py +++ b/dashboard/tests/test_dashboard.py @@ -107,7 +107,6 @@ def test_basic(ray_start_with_dashboard): """Dashboard test that starts a Ray cluster with a dashboard server running, then hits the dashboard API and asserts that it receives sensible data.""" address_info = ray_start_with_dashboard - node_id = address_info["node_id"] gcs_client = make_gcs_client(address_info) ray.experimental.internal_kv._initialize_internal_kv(gcs_client) @@ -143,11 +142,6 @@ def test_basic(ray_start_with_dashboard): namespace=ray_constants.KV_NAMESPACE_DASHBOARD, ) assert dashboard_rpc_address is not None - key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{node_id}" - agent_ports = ray.experimental.internal_kv._internal_kv_get( - key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD - ) - assert agent_ports is not None def test_raylet_and_agent_share_fate(shutdown_only): @@ -792,7 +786,6 @@ def test_dashboard_port_conflict(ray_start_with_dashboard): ) def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard): assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True - all_processes = ray._private.worker._global_node.all_processes dashboard_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0] dashboard_proc = psutil.Process(dashboard_info.process.pid) diff --git a/python/ray/_private/state.py b/python/ray/_private/state.py index 76ff68867121..70705c38d74e 100644 --- a/python/ray/_private/state.py +++ b/python/ray/_private/state.py @@ -164,6 +164,12 @@ def node_table(self): "RayletSocketName": item.raylet_socket_name, "MetricsExportPort": item.metrics_export_port, "NodeName": item.node_name, + "AgentInfo": { + "IpAddress": item.agent_info.ip_address, + "GrpcPort": item.agent_info.grpc_port, + "HttpPort": item.agent_info.http_port, + "Pid": item.agent_info.pid, + }, } node_info["alive"] = node_info["Alive"] node_info["Resources"] = ( diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 2f9172a04ced..2b5d76cde576 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -68,7 +68,8 @@ py_test_module_list( "test_healthcheck.py", "test_kill_raylet_signal_log.py", "test_memstat.py", - "test_protobuf_compatibility.py" + "test_protobuf_compatibility.py", + "test_scheduling_performance.py" ], size = "medium", tags = ["exclusive", "medium_size_python_tests_a_to_j", "team:core"], @@ -120,7 +121,6 @@ py_test_module_list( "test_multi_node_2.py", "test_multinode_failures.py", "test_multinode_failures_2.py", - "test_multiprocessing.py", "test_object_assign_owner.py", "test_placement_group.py", "test_placement_group_2.py", @@ -138,7 +138,6 @@ py_test_module_list( "test_runtime_env_fork_process.py", "test_serialization.py", "test_shuffle.py", - "test_state_api.py", "test_state_api_log.py", "test_state_api_summary.py", "test_stress.py", @@ -186,7 +185,6 @@ py_test_module_list( "test_cross_language.py", "test_environ.py", "test_raylet_output.py", - "test_scheduling_performance.py", "test_get_or_create_actor.py", ], size = "small", @@ -265,6 +263,7 @@ py_test_module_list( "test_chaos.py", "test_reference_counting_2.py", "test_exit_observability.py", + "test_state_api.py", "test_usage_stats.py", ], size = "large", @@ -295,6 +294,7 @@ py_test_module_list( "test_placement_group_mini_integration.py", "test_scheduling_2.py", "test_multi_node_3.py", + "test_multiprocessing.py", ], size = "large", tags = ["exclusive", "large_size_python_tests_shard_1", "team:core"], diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index 2a9edfb4630d..50d21de73673 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -15,6 +15,7 @@ from tempfile import gettempdir from typing import List, Tuple from unittest import mock +import signal import pytest @@ -204,10 +205,19 @@ def _ray_start(**kwargs): init_kwargs.update(kwargs) # Start the Ray processes. address_info = ray.init("local", **init_kwargs) + agent_pids = [] + for node in ray.nodes(): + agent_pids.append(int(node["AgentInfo"]["Pid"])) yield address_info # The code after the yield will run as teardown code. ray.shutdown() + # Make sure the agent process is dead. + for pid in agent_pids: + try: + os.kill(pid, signal.SIGKILL) + except Exception: + pass # Delete the cluster address just in case. ray._private.utils.reset_ray_address() diff --git a/python/ray/tests/test_cli.py b/python/ray/tests/test_cli.py index fb5f6aee0903..0e5696671ec8 100644 --- a/python/ray/tests/test_cli.py +++ b/python/ray/tests/test_cli.py @@ -834,8 +834,9 @@ def output_ready(): @pytest.mark.xfail(cluster_not_supported, reason="cluster not supported on Windows") def test_ray_status_multinode(ray_start_cluster): + NODE_NUMBER = 4 cluster = ray_start_cluster - for _ in range(4): + for _ in range(NODE_NUMBER): cluster.add_node(num_cpus=2) runner = CliRunner() @@ -850,8 +851,12 @@ def output_ready(): wait_for_condition(output_ready) - result = runner.invoke(scripts.status, []) - _check_output_via_pattern("test_ray_status_multinode.txt", result) + def check_result(): + result = runner.invoke(scripts.status, []) + _check_output_via_pattern("test_ray_status_multinode.txt", result) + return True + + wait_for_condition(check_result) @pytest.mark.skipif( diff --git a/python/ray/tests/test_ray_shutdown.py b/python/ray/tests/test_ray_shutdown.py index 8ce6552fceba..0b34f6573f1a 100644 --- a/python/ray/tests/test_ray_shutdown.py +++ b/python/ray/tests/test_ray_shutdown.py @@ -11,17 +11,25 @@ def get_all_ray_worker_processes(): - processes = [ - p.info["cmdline"] for p in psutil.process_iter(attrs=["pid", "name", "cmdline"]) - ] + processes = psutil.process_iter(attrs=["pid", "name", "cmdline"]) result = [] for p in processes: - if p is not None and len(p) > 0 and "ray::" in p[0]: + cmd_line = p.info["cmdline"] + if cmd_line is not None and len(cmd_line) > 0 and "ray::" in cmd_line[0]: result.append(p) return result +def kill_all_ray_worker_process(): + ray_process = get_all_ray_worker_processes() + for p in ray_process: + try: + p.kill() + except Exception: + pass + + @pytest.fixture def short_gcs_publish_timeout(monkeypatch): monkeypatch.setenv("RAY_MAX_GCS_PUBLISH_RETRIES", "3") @@ -31,6 +39,10 @@ def short_gcs_publish_timeout(monkeypatch): @pytest.mark.skipif(platform.system() == "Windows", reason="Hang on Windows.") def test_ray_shutdown(short_gcs_publish_timeout, shutdown_only): """Make sure all ray workers are shutdown when driver is done.""" + # Avoiding the previous test doesn't kill the relevant process, + # thus making the current test fail. + kill_all_ray_worker_process() + ray.init() @ray.remote @@ -51,6 +63,10 @@ def f(): @pytest.mark.skipif(platform.system() == "Windows", reason="Hang on Windows.") def test_driver_dead(short_gcs_publish_timeout, shutdown_only): """Make sure all ray workers are shutdown when driver is killed.""" + # Avoiding the previous test doesn't kill the relevant process, + # thus making the current test fail. + kill_all_ray_worker_process() + driver = """ import ray ray.init(_system_config={"gcs_rpc_server_reconnect_timeout_s": 1}) @@ -80,6 +96,10 @@ def f(): @pytest.mark.skipif(platform.system() == "Windows", reason="Hang on Windows.") def test_node_killed(short_gcs_publish_timeout, ray_start_cluster): """Make sure all ray workers when nodes are dead.""" + # Avoiding the previous test doesn't kill the relevant process, + # thus making the current test fail. + kill_all_ray_worker_process() + cluster = ray_start_cluster # head node. cluster.add_node( @@ -112,6 +132,10 @@ def f(): @pytest.mark.skipif(platform.system() == "Windows", reason="Hang on Windows.") def test_head_node_down(short_gcs_publish_timeout, ray_start_cluster): """Make sure all ray workers when head node is dead.""" + # Avoiding the previous test doesn't kill the relevant process, + # thus making the current test fail. + kill_all_ray_worker_process() + cluster = ray_start_cluster # head node. head = cluster.add_node( diff --git a/python/ray/tests/test_state_api.py b/python/ray/tests/test_state_api.py index 128c201ef8c7..30a8a938f5de 100644 --- a/python/ray/tests/test_state_api.py +++ b/python/ray/tests/test_state_api.py @@ -10,7 +10,6 @@ from click.testing import CliRunner import ray -import ray.dashboard.consts as dashboard_consts import ray._private.state as global_state import ray._private.ray_constants as ray_constants from ray._private.test_utils import ( @@ -1160,16 +1159,8 @@ async def test_state_data_source_client(ray_start_cluster): wait_for_condition(lambda: len(ray.nodes()) == 2) for node in ray.nodes(): node_id = node["NodeID"] - key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{node_id}" - - def get_port(): - return ray.experimental.internal_kv._internal_kv_get( - key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD - ) - - wait_for_condition(lambda: get_port() is not None) # The second index is the gRPC port - port = json.loads(get_port())[1] + port = node["AgentInfo"]["GrpcPort"] ip = node["NodeManagerAddress"] client.register_agent_client(node_id, ip, port) result = await client.get_runtime_envs_info(node_id) @@ -1371,16 +1362,8 @@ async def verify(): """ for node in ray.nodes(): node_id = node["NodeID"] - key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{node_id}" - - def get_port(): - return ray.experimental.internal_kv._internal_kv_get( - key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD - ) - - wait_for_condition(lambda: get_port() is not None) # The second index is the gRPC port - port = json.loads(get_port())[1] + port = node["AgentInfo"]["GrpcPort"] ip = node["NodeManagerAddress"] client.register_agent_client(node_id, ip, port) @@ -1493,8 +1476,13 @@ def verify_output(cmd, args: List[str], necessary_substrings: List[str]): ) ) # Test get workers by id + + # Still need a `wait_for_condition`, + # because the worker obtained through the api server will not filter the driver, + # but `global_state.workers` will filter the driver. + wait_for_condition(lambda: len(global_state.workers()) > 0) workers = global_state.workers() - assert len(workers) > 0 + worker_id = list(workers.keys())[0] wait_for_condition( lambda: verify_output(cli_get, ["workers", worker_id], ["worker_id", worker_id]) diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 2c37bed4ec4f..6492e171183a 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -76,6 +76,9 @@ RAY_CONFIG(uint64_t, raylet_report_resources_period_milliseconds, 100) /// The duration between raylet check memory pressure and send gc request RAY_CONFIG(uint64_t, raylet_check_gc_period_milliseconds, 100) +/// If the raylet fails to get agent info, we will retry after this interval. +RAY_CONFIG(uint64_t, raylet_get_agent_info_interval_ms, 1) + /// For a raylet, if the last resource report was sent more than this many /// report periods ago, then a warning will be logged that the report /// handler is drifting. diff --git a/src/ray/protobuf/BUILD b/src/ray/protobuf/BUILD index 9dc2345f8f7a..e1ec19f8817d 100644 --- a/src/ray/protobuf/BUILD +++ b/src/ray/protobuf/BUILD @@ -243,7 +243,7 @@ python_grpc_compile( proto_library( name = "agent_manager_proto", srcs = ["agent_manager.proto"], - deps = [], + deps = [":common_proto"], ) python_grpc_compile( diff --git a/src/ray/protobuf/agent_manager.proto b/src/ray/protobuf/agent_manager.proto index bf438ece0464..5d6f08e657fd 100644 --- a/src/ray/protobuf/agent_manager.proto +++ b/src/ray/protobuf/agent_manager.proto @@ -17,6 +17,8 @@ option cc_enable_arenas = true; package ray.rpc; +import "src/ray/protobuf/common.proto"; + enum AgentRpcStatus { // OK. AGENT_RPC_STATUS_OK = 0; @@ -25,9 +27,7 @@ enum AgentRpcStatus { } message RegisterAgentRequest { - int32 agent_id = 1; - int32 agent_port = 2; - string agent_ip_address = 3; + AgentInfo agent_info = 1; } message RegisterAgentReply { diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index 8123fa1b36bb..d3b1eaf07b0e 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -677,3 +677,17 @@ message NamedActorInfo { string ray_namespace = 1; string name = 2; } + +// Info about a agent process. +message AgentInfo { + // The agent id. + int32 id = 1; + // The agent process pid. + int64 pid = 2; + // IP address of the agent process. + string ip_address = 3; + // The GRPC port number of the agent process. + int32 grpc_port = 4; + // The http port number of the agent process. + int32 http_port = 5; +} \ No newline at end of file diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index 66251bc533bd..5c1ad0bcc54f 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -247,6 +247,8 @@ message GcsNodeInfo { // The user-provided identifier or name for this node. string node_name = 12; + // The information of the agent process. + AgentInfo agent_info = 13; } message HeartbeatTableData { diff --git a/src/ray/raylet/agent_manager.cc b/src/ray/raylet/agent_manager.cc index 08e3256e520e..8ee8440dcdb8 100644 --- a/src/ray/raylet/agent_manager.cc +++ b/src/ray/raylet/agent_manager.cc @@ -29,19 +29,20 @@ namespace raylet { void AgentManager::HandleRegisterAgent(const rpc::RegisterAgentRequest &request, rpc::RegisterAgentReply *reply, rpc::SendReplyCallback send_reply_callback) { - reported_agent_ip_address_ = request.agent_ip_address(); - reported_agent_port_ = request.agent_port(); - reported_agent_id_ = request.agent_id(); + reported_agent_info_.CopyFrom(request.agent_info()); // TODO(SongGuyang): We should remove this after we find better port resolution. - // Note: `agent_port_` should be 0 if the grpc port of agent is in conflict. - if (reported_agent_port_ != 0) { + // Note: `reported_agent_info_.grpc_port()` should be 0 if the grpc port of agent is in + // conflict. + if (reported_agent_info_.grpc_port() != 0) { runtime_env_agent_client_ = runtime_env_agent_client_factory_( - reported_agent_ip_address_, reported_agent_port_); - RAY_LOG(INFO) << "HandleRegisterAgent, ip: " << reported_agent_ip_address_ - << ", port: " << reported_agent_port_ << ", id: " << reported_agent_id_; + reported_agent_info_.ip_address(), reported_agent_info_.grpc_port()); + RAY_LOG(INFO) << "HandleRegisterAgent, ip: " << reported_agent_info_.ip_address() + << ", port: " << reported_agent_info_.grpc_port() + << ", id: " << reported_agent_info_.id(); } else { RAY_LOG(WARNING) << "The GRPC port of the Ray agent is invalid (0), ip: " - << reported_agent_ip_address_ << ", id: " << reported_agent_id_ + << reported_agent_info_.ip_address() + << ", id: " << reported_agent_info_.id() << ". The agent client in the raylet has been disabled."; disable_agent_client_ = true; } @@ -56,16 +57,19 @@ void AgentManager::StartAgent() { return; } - // Create a non-zero random agent_id to pass to the child process + // Create a non-zero random agent_id_ to pass to the child process // We cannot use pid an id because os.getpid() from the python process is not // reliable when using a launcher. // See https://github.com/ray-project/ray/issues/24361 and Python issue // https://github.com/python/cpython/issues/83086 - int agent_id = 0; - while (agent_id == 0) { - agent_id = rand(); + agent_id_ = 0; + while (agent_id_ == 0) { + agent_id_ = rand(); } - const std::string agent_id_str = std::to_string(agent_id); + // Make sure reported_agent_info_.id() not equal + // `agent_id_` before the agent finished register. + reported_agent_info_.set_id(0); + const std::string agent_id_str = std::to_string(agent_id_); std::vector argv; for (const std::string &arg : options_.agent_commands) { argv.push_back(arg.c_str()); @@ -104,21 +108,22 @@ void AgentManager::StartAgent() { << ec.message(); } - std::thread monitor_thread([this, child, agent_id]() mutable { + std::thread monitor_thread([this, child]() mutable { SetThreadName("agent.monitor"); - RAY_LOG(INFO) << "Monitor agent process with id " << agent_id << ", register timeout " + RAY_LOG(INFO) << "Monitor agent process with id " << child.GetId() + << ", register timeout " << RayConfig::instance().agent_register_timeout_ms() << "ms."; auto timer = delay_executor_( - [this, child, agent_id]() mutable { - if (reported_agent_id_ != agent_id) { - if (reported_agent_id_ == 0) { - RAY_LOG(WARNING) << "Agent process expected id " << agent_id + [this, child]() mutable { + if (!IsAgentRegistered()) { + if (reported_agent_info_.id() == 0) { + RAY_LOG(WARNING) << "Agent process expected id " << agent_id_ << " timed out before registering. ip " - << reported_agent_ip_address_ << ", id " - << reported_agent_id_; + << reported_agent_info_.ip_address() << ", id " + << reported_agent_info_.id(); } else { - RAY_LOG(WARNING) << "Agent process expected id " << agent_id - << " but got id " << reported_agent_id_ + RAY_LOG(WARNING) << "Agent process expected id " << agent_id_ + << " but got id " << reported_agent_info_.id() << ", this is a fatal error"; } child.Kill(); @@ -128,9 +133,9 @@ void AgentManager::StartAgent() { int exit_code = child.Wait(); timer->cancel(); - RAY_LOG(WARNING) << "Agent process with id " << agent_id << " exited, return value " - << exit_code << ". ip " << reported_agent_ip_address_ << ". id " - << reported_agent_id_; + RAY_LOG(WARNING) << "Agent process with id " << agent_id_ << " exited, return value " + << exit_code << ". ip " << reported_agent_info_.ip_address() + << ". id " << reported_agent_info_.id(); RAY_LOG(ERROR) << "The raylet exited immediately because the Ray agent failed. " "The raylet fate shares with the agent. This can happen because the " @@ -303,5 +308,15 @@ void AgentManager::DeleteRuntimeEnvIfPossible( }); } +const ray::Status AgentManager::TryToGetAgentInfo(rpc::AgentInfo *agent_info) const { + if (IsAgentRegistered()) { + *agent_info = reported_agent_info_; + return ray::Status::OK(); + } else { + std::string err_msg = "The agent has not finished register yet."; + return ray::Status::Invalid(err_msg); + } +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/agent_manager.h b/src/ray/raylet/agent_manager.h index ef460b96c9f6..6c453da85187 100644 --- a/src/ray/raylet/agent_manager.h +++ b/src/ray/raylet/agent_manager.h @@ -23,6 +23,7 @@ #include "ray/rpc/agent_manager/agent_manager_server.h" #include "ray/rpc/runtime_env/runtime_env_client.h" #include "ray/util/process.h" +#include "src/ray/protobuf/gcs.pb.h" namespace ray { namespace raylet { @@ -88,17 +89,28 @@ class AgentManager : public rpc::AgentManagerServiceHandler { virtual void DeleteRuntimeEnvIfPossible(const std::string &serialized_runtime_env, DeleteRuntimeEnvIfPossibleCallback callback); + /// Try to Get the information about the agent process. + /// + /// \param[out] agent_info The information of the agent process. + /// \return Status, if successful will return `ray::Status::OK`, + /// otherwise will return `ray::Status::Invalid`. + const ray::Status TryToGetAgentInfo(rpc::AgentInfo *agent_info) const; + private: void StartAgent(); + const bool IsAgentRegistered() const { return reported_agent_info_.id() == agent_id_; } + private: Options options_; - pid_t reported_agent_id_ = 0; - int reported_agent_port_ = 0; + /// we need to make sure `agent_id_` and `reported_agent_info_.id()` are not equal + /// until the agent process is finished registering, the initial value of + /// `reported_agent_info_.id()` is 0, so I set the initial value of `agent_id_` is -1 + int agent_id_ = -1; + rpc::AgentInfo reported_agent_info_; /// Whether or not we intend to start the agent. This is false if we /// are missing Ray Dashboard dependencies, for example. bool should_start_agent_ = true; - std::string reported_agent_ip_address_; DelayExecutorFn delay_executor_; RuntimeEnvAgentClientFactoryFn runtime_env_agent_client_factory_; std::shared_ptr runtime_env_agent_client_; diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 1c4950f873f6..342262f9a314 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -2860,6 +2860,10 @@ void NodeManager::PublishInfeasibleTaskError(const RayTask &task) const { } } +const ray::Status NodeManager::TryToGetAgentInfo(rpc::AgentInfo *agent_info) const { + return agent_manager_->TryToGetAgentInfo(agent_info); +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 7af5f7c46a8f..4d63fc597bb3 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -237,6 +237,13 @@ class NodeManager : public rpc::NodeManagerServiceHandler, int64_t limit, const std::function &on_all_replied); + /// Try to Get the information about the agent process. + /// + /// \param[out] agent_info The information of the agent process. + /// \return Status, if successful will return `ray::Status::OK`, + /// otherwise will return `ray::Status::Invalid`. + const ray::Status TryToGetAgentInfo(rpc::AgentInfo *agent_info) const; + private: /// Methods for handling nodes. diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index e7171661bd70..6c8c258ddeb5 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -97,7 +97,6 @@ Raylet::~Raylet() {} void Raylet::Start() { RAY_CHECK_OK(RegisterGcs()); - // Start listening for clients. DoAccept(); } @@ -109,6 +108,21 @@ void Raylet::Stop() { } ray::Status Raylet::RegisterGcs() { + rpc::AgentInfo agent_info; + auto status = node_manager_.TryToGetAgentInfo(&agent_info); + if (status.ok()) { + self_node_info_.mutable_agent_info()->CopyFrom(agent_info); + } else { + // Because current function and `AgentManager::HandleRegisterAgent` + // will be invoke in same thread, so we need post current function + // into main_service_ after interval milliseconds. + std::this_thread::sleep_for(std::chrono::milliseconds( + RayConfig::instance().raylet_get_agent_info_interval_ms())); + main_service_.post([this]() { RAY_CHECK_OK(RegisterGcs()); }, + "Raylet.TryToGetAgentInfoAndRegisterGcs"); + return Status::OK(); + } + auto register_callback = [this](const Status &status) { RAY_CHECK_OK(status); RAY_LOG(INFO) << "Raylet of id, " << self_node_id_ diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h index d2618d1ac32d..b5791d9288c6 100644 --- a/src/ray/raylet/raylet.h +++ b/src/ray/raylet/raylet.h @@ -68,7 +68,7 @@ class Raylet { NodeID GetNodeId() const { return self_node_id_; } private: - /// Register GCS client. + /// Try to get agent info, after its success, register the current node to GCS. ray::Status RegisterGcs(); /// Accept a client connection. diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index a6c3a37141d3..e3f1893dbc45 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -503,7 +503,8 @@ class WorkerPoolTest : public ::testing::Test { false); rpc::RegisterAgentRequest request; // Set agent port to a nonzero value to avoid invalid agent client. - request.set_agent_port(12345); + request.mutable_agent_info()->set_grpc_port(12345); + request.mutable_agent_info()->set_http_port(54321); rpc::RegisterAgentReply reply; auto send_reply_callback = [](ray::Status status, std::function f1, std::function f2) {};