Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DeviceAPI::isAgent() DeviceAPI::isState() #1116

Merged
merged 8 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions include/flamegpu/runtime/DeviceAPI.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,26 @@ class ReadOnlyDeviceAPI {
#endif
return blockIdx.x * blockDim.x + threadIdx.x;
}
/**
* When passed an agent name, returns a boolean to confirm whether it matches the name of the current agent
*
* This function may be useful if an agent function is shared between multiple agents
*
* @note The performance of this function is unlikely to be cheap unless used as part of an RTC agent function.
*/
__forceinline__ __device__ bool isAgent(const char* agent_name) {
return detail::curve::DeviceCurve::isAgent(agent_name);
}
/**
* When passed an agent state, returns a boolean to confirm whether it matches the name of the agent input state of the current agent function
*
* This function may be useful if an agent function is shared between multiple agent states
*
* @note The performance of this function is unlikely to be cheap unless used as part of an RTC agent function (whereby it can be processed at compile time).
*/
__forceinline__ __device__ bool isState(const char* agent_state) {
return detail::curve::DeviceCurve::isState(agent_state);
}
};

/** @brief A flame gpu api class for the device runtime only
Expand Down Expand Up @@ -336,6 +356,26 @@ class DeviceAPI {
#endif
return blockIdx.x * blockDim.x + threadIdx.x;
}
/**
* When passed an agent name, returns a boolean to confirm whether it matches the name of the current agent
*
* This function may be useful if an agent function is shared between multiple agents
*
* @note The performance of this function is unlikely to be cheap unless used as part of an RTC agent function.
*/
__forceinline__ __device__ bool isAgent(const char* agent_name) {
return detail::curve::DeviceCurve::isAgent(agent_name);
}
/**
* When passed an agent state, returns a boolean to confirm whether it matches the name of the agent input state of the current agent function
*
* This function may be useful if an agent function is shared between multiple agent states
*
* @note The performance of this function is unlikely to be cheap unless used as part of an RTC agent function (whereby it can be processed at compile time).
*/
__forceinline__ __device__ bool isState(const char* agent_state) {
return detail::curve::DeviceCurve::isState(agent_state);
}

/**
* Provides access to message read functionality inside agent functions
Expand Down
26 changes: 26 additions & 0 deletions include/flamegpu/runtime/detail/curve/DeviceCurve.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,23 @@ class DeviceCurve {
*/
template<typename T, unsigned int I = 1, unsigned int J = 1, unsigned int K = 1, unsigned int W = 1, unsigned int M>
__device__ __forceinline__ static char *getEnvironmentMacroProperty(const char(&name)[M]);

/**
* When passed an agent name, returns a boolean to confirm whether it matches the name of the current agent
*
* This function may be useful if an agent function is shared between multiple agents
*
* @note The performance of this function is unlikely to be cheap unless used as part of an RTC agent function.
*/
__device__ __forceinline__ static bool isAgent(const char* agent_name);
/**
* When passed an agent state, returns a boolean to confirm whether it matches the name of the agent input state of the current agent function
*
* This function may be useful if an agent function is shared between multiple agent states
*
* @note The performance of this function is unlikely to be cheap unless used as part of an RTC agent function (whereby it can be processed at compile time).
*/
__device__ __forceinline__ static bool isState(const char* agent_state);
};

////
Expand Down Expand Up @@ -398,6 +415,15 @@ template<typename T, unsigned int I, unsigned int J, unsigned int K, unsigned in
__device__ __forceinline__ char* DeviceCurve::getEnvironmentMacroProperty(const char(&name)[M]) {
return getVariablePtr<T, I*J*K*W>(name, Curve::variableHash("_macro_environment"), 0);
}

__device__ __forceinline__ bool DeviceCurve::isAgent(const char* agent_name) {
return strcmp(agent_name, "todo") == 0; // @todo
}
__device__ __forceinline__ bool DeviceCurve::isState(const char* agent_state) {
return strcmp(agent_state, "todo") == 0; // @todo
}


} // namespace curve
} // namespace detail
} // namespace flamegpu
Expand Down
18 changes: 18 additions & 0 deletions include/flamegpu/runtime/detail/curve/curve_rtc.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,16 @@ class CurveRTCHost {
* @throws exception::UnknownInternalError If the specified property is not registered
*/
void unregisterEnvMacroProperty(const char* propertyName);
/**
* Register the name of the agent and it's state of the agent function
*
* Used by ReadOnlyDeviceAPI::isAgent() and ReadOnlyDeviceAPI::isState()
*
* @param agentName Name of the agent
* @param agentState Name of the agent's state
* @throws exception::UnknownInternalError If the agent has already been registered
*/
void registerAgent(const std::string &agentName, const std::string &agentState);
/**
* Set the filename tagged in the file (goes into a #line statement)
* @param filename Name to be used for the file in compile errors
Expand Down Expand Up @@ -378,6 +388,14 @@ class CurveRTCHost {
* <name, RTCVariableProperties>
*/
std::map<std::string, RTCEnvMacroPropertyProperties> RTCEnvMacroProperties;
/**
* Agent name for ReadOnlyDeviceAPI::isAgent()
*/
std::string agentName;
/**
* Agent name for ReadOnlyDeviceAPI::isState()
*/
std::string agentState;
};

} // namespace curve
Expand Down
28 changes: 28 additions & 0 deletions src/flamegpu/runtime/detail/curve/curve_rtc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ class DeviceCurve {
template <typename T, unsigned int N, unsigned int M>
__device__ __forceinline__ static void setNewAgentArrayVariable(const char(&name)[M], T variable, unsigned int variable_index, unsigned int array_index);

__device__ __forceinline__ static bool isAgent(const char* agent_name);
__device__ __forceinline__ static bool isState(const char* agent_state);
};

template <typename T, unsigned int N>
Expand Down Expand Up @@ -170,6 +172,22 @@ __device__ __forceinline__ void DeviceCurve::setNewAgentArrayVariable(const char
$DYNAMIC_SETNEWAGENTARRAYVARIABLE_IMPL
}

// https://stackoverflow.com/a/34873763/1646387
__device__ __forceinline__ int strcmp(const char *s1, const char *s2) {
const unsigned char *p1 = (const unsigned char *)s1;
const unsigned char *p2 = (const unsigned char *)s2;

while(*p1 && *p1 == *p2) ++p1, ++p2;

return (*p1 > *p2) - (*p2 > *p1);
}
__device__ __forceinline__ bool DeviceCurve::isAgent(const char* agent_name) {
return strcmp(agent_name, "$DYNAMIC_AGENT_NAME") == 0;
}
__device__ __forceinline__ bool DeviceCurve::isState(const char* agent_state) {
return strcmp(agent_state, "$DYNAMIC_AGENT_STATE") == 0;
}

} // namespace curve
} // namespace detail
} // namespace flamegpu
Expand Down Expand Up @@ -331,6 +349,14 @@ void CurveRTCHost::registerEnvVariable(const char* propertyName, ptrdiff_t offse
THROW exception::UnknownInternalError("Environment property with name '%s' is already registered, in CurveRTCHost::registerEnvVariable()", propertyName);
}
}
void CurveRTCHost::registerAgent(const std::string &_agentName, const std::string &_agentState) {
if (this->agentName.empty()) {
this->agentName = _agentName;
this->agentState = _agentState;
} else {
THROW exception::UnknownInternalError("Agent is already registered with name '%s' and state '%s', in CurveRTCHost::registerAgent()", this->agentName.c_str(), this->agentState.c_str());
}
}

void CurveRTCHost::unregisterEnvVariable(const char* propertyName) {
auto i = RTCEnvVariables.find(propertyName);
Expand Down Expand Up @@ -922,6 +948,8 @@ void CurveRTCHost::initHeaderGetters() {
getMessageArrayVariableLDGImpl << " return {};\n";
setHeaderPlaceholder("$DYNAMIC_GETMESSAGEARRAYVARIABLE_LDG_IMPL", getMessageArrayVariableLDGImpl.str());
}
setHeaderPlaceholder("$DYNAMIC_AGENT_NAME", this->agentName);
setHeaderPlaceholder("$DYNAMIC_AGENT_STATE", this->agentState);
}
void CurveRTCHost::initDataBuffer() {
if (data_buffer_size == 0 || h_data_buffer) {
Expand Down
5 changes: 4 additions & 1 deletion src/flamegpu/simulation/detail/CUDAAgent.cu
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ void CUDAAgent::setPopulationData(const AgentVector& population, const std::stri
if (state_name == ModelData::DEFAULT_STATE) {
THROW exception::InvalidAgentState("Agent '%s' does not use the default state, so the state must be passed explicitly, "
"in CUDAAgent::setPopulationData()",
state_name.c_str(), population.getAgentName().c_str());
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had this exception thrown whilst writing test, noticed it was wrong.

population.getAgentName().c_str());
} else {
THROW exception::InvalidAgentState("State '%s' was not found in agent '%s', "
"in CUDAAgent::setPopulationData()",
Expand Down Expand Up @@ -502,6 +502,9 @@ void CUDAAgent::addInstantitateRTCFunction(const AgentFunctionData& func, const
// Set Environment macro properties in curve
macro_env->mapRTCVariables(curve_header);

// Set the agent name/state
curve_header.registerAgent(this->agent_description.name, func.initial_state);

std::string header_filename = std::string(func.rtc_func_name).append("_impl");
if (function_condition)
header_filename.append("_condition");
Expand Down
45 changes: 44 additions & 1 deletion tests/python/runtime/test_device_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ class DeviceAPITest(TestCase):
return flamegpu::ALIVE;
}
"""


agent_fn_check_agent_name_state = """
FLAMEGPU_AGENT_FUNCTION(check_agent_name_state, flamegpu::MessageNone, flamegpu::MessageNone){
FLAMEGPU->setVariable<int>("correct_name", static_cast<int>(FLAMEGPU->isAgent("agent")));
FLAMEGPU->setVariable<int>("wrong_name", static_cast<int>(FLAMEGPU->isAgent("agent3")));
FLAMEGPU->setVariable<int>("correct_state", static_cast<int>(FLAMEGPU->isState("state")));
FLAMEGPU->setVariable<int>("wrong_state", static_cast<int>(FLAMEGPU->isState("state5")));
return flamegpu::ALIVE;
}
"""

def test_agent_death_array(self):
model = pyflamegpu.ModelDescription("test_agent_death_array")
Expand Down Expand Up @@ -139,6 +150,7 @@ def test_array_set(self):
assert output_array[3] == 16 + j



def test_array_get(self):
model = pyflamegpu.ModelDescription("test_array_get")
agent = model.newAgent("agent_name")
Expand Down Expand Up @@ -184,4 +196,35 @@ def test_array_get(self):
assert instance.getVariableInt("a2") == 4 + j
assert instance.getVariableInt("a3") == 8 + j
assert instance.getVariableInt("a4") == 16 + j



def test_check_agent_name_state(self):
model = pyflamegpu.ModelDescription("test_array_get")
agent = model.newAgent("agent")
agent.newState("state")
agent.newState("state8")
agent.newVariableInt("correct_name", -1)
agent.newVariableInt("wrong_name", -1)
agent.newVariableInt("correct_state", -1)
agent.newVariableInt("wrong_state", -1)
# Do nothing, but ensure variables are made available on device
func = agent.newRTCFunction("some_function", self.agent_fn_check_agent_name_state)
model.newLayer().addAgentFunction(func)
# Init pop
init_population = pyflamegpu.AgentVector(agent, AGENT_COUNT)

# Setup Model
cudaSimulation = pyflamegpu.CUDASimulation(model)
cudaSimulation.setPopulationData(init_population, "state")
# Run 1 step to ensure data is pushed to device
cudaSimulation.step()
# Recover data from device
population = pyflamegpu.AgentVector(agent, AGENT_COUNT)
cudaSimulation.getPopulationData(population, "state")
# Check results are correct
assert len(population) == AGENT_COUNT
for instance in population:
assert instance.getVariableInt("correct_name") == 1
assert instance.getVariableInt("wrong_name") == 0
assert instance.getVariableInt("correct_state") == 1
assert instance.getVariableInt("wrong_state") == 0