Skip to content

Commit

Permalink
cpp client update, legacy bug fix in traffic light states
Browse files Browse the repository at this point in the history
  • Loading branch information
rf-ivtdai committed Jan 11, 2024
1 parent 4ccdea6 commit 9ff0510
Show file tree
Hide file tree
Showing 9 changed files with 269 additions and 40 deletions.
8 changes: 8 additions & 0 deletions invertedai_cpp/invertedai/data_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,14 @@ struct TrafficLightState {
std::string value;
};

/**
* Light recurrent state that contains the current state and ticks remaining in this state.
*/
struct LightRecurrentState {
int state;
int ticks_remaining;
};

/**
* Infractions committed by a given agent, as returned from invertedai::drive().
*/
Expand Down
70 changes: 55 additions & 15 deletions invertedai_cpp/invertedai/drive_request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,33 @@ DriveRequest::DriveRequest(const std::string &body_str) {
}
this->recurrent_states_.push_back(recurrent_state);
}
this->traffic_lights_states_.clear();
for (const auto &element : this->body_json_["traffic_lights_states"]) {
TrafficLightState traffic_light_state = {
element[0],
element[1]
};
this->traffic_lights_states_.push_back(traffic_light_state);
if (this->body_json_["traffic_lights_states"].is_null()) {
this->traffic_lights_states_ = std::nullopt;
} else {
if (this->traffic_lights_states_.has_value()) {
this->traffic_lights_states_.value().clear();
} else {
this->traffic_lights_states_ = std::map<std::string, std::string>();
}
for (const auto &element : this->body_json_["traffic_lights_states"].items()) {
this->traffic_lights_states_.value()[element.key()] = element.value();
}
}
if (this->body_json_["light_recurrent_states"].is_null()) {
this->light_recurrent_states_ = std::nullopt;
} else {
if (this->light_recurrent_states_.has_value()) {
this->light_recurrent_states_.value().clear();
} else {
this->light_recurrent_states_ = std::vector<LightRecurrentState>();
}
for (const auto &element : this->body_json_["light_recurrent_states"]) {
LightRecurrentState light_recurrent_state = {
element[0],
element[1]
};
this->light_recurrent_states_.value().push_back(light_recurrent_state);
}
}
this->get_birdview_ = this->body_json_["get_birdview"].is_boolean()
? this->body_json_["get_birdview"].get<bool>()
Expand Down Expand Up @@ -89,12 +109,24 @@ void DriveRequest::refresh_body_json_() {
this->body_json_["recurrent_states"].push_back(elements);
}
this->body_json_["traffic_lights_states"].clear();
for (const TrafficLightState &traffic_light_state : this->traffic_lights_states_) {
json element = {
traffic_light_state.id,
traffic_light_state.value
};
this->body_json_["traffic_lights_states"].push_back(element);
if (this->traffic_lights_states_.has_value()) {
for (const auto &pair : this->traffic_lights_states_.value()) {
this->body_json_["traffic_lights_states"][pair.first] = pair.second;
}
} else {
this->body_json_["traffic_lights_states"] = nullptr;
}
this->body_json_["light_recurrent_states"].clear();
if (this->light_recurrent_states_.has_value()) {
for (const LightRecurrentState &light_recurrent_state : this->light_recurrent_states_.value()) {
json element = {
light_recurrent_state.state,
light_recurrent_state.ticks_remaining
};
this->body_json_["light_recurrent_states"].push_back(element);
}
} else {
this->body_json_["light_recurrent_states"] = nullptr;
}
this->body_json_["get_birdview"] = this->get_birdview_;
this->body_json_["get_infractions"] = this->get_infractions_;
Expand Down Expand Up @@ -149,14 +181,18 @@ std::vector<AgentAttributes> DriveRequest::agent_attributes() const {
return this->agent_attributes_;
};

std::vector<TrafficLightState> DriveRequest::traffic_lights_states() const {
std::optional<std::map<std::string, std::string>> DriveRequest::traffic_lights_states() const {
return this->traffic_lights_states_;
};

std::vector<std::vector<double>> DriveRequest::recurrent_states() const {
return this->recurrent_states_;
};

std::optional<std::vector<LightRecurrentState>> DriveRequest::light_recurrent_states() const {
return this->light_recurrent_states_;
};

bool DriveRequest::get_birdview() const {
return this->get_birdview_;
}
Expand Down Expand Up @@ -193,10 +229,14 @@ void DriveRequest::set_agent_attributes(const std::vector<AgentAttributes> &agen
this->agent_attributes_ = agent_attributes;
}

void DriveRequest::set_traffic_lights_states(const std::vector<TrafficLightState> &traffic_lights_states) {
void DriveRequest::set_traffic_lights_states(const std::map<std::string, std::string> &traffic_lights_states) {
this->traffic_lights_states_ = traffic_lights_states;
}

void DriveRequest::set_light_recurrent_states(const std::vector<LightRecurrentState> &light_recurrent_states) {
this->light_recurrent_states_ = light_recurrent_states;
}

void DriveRequest::set_recurrent_states(const std::vector<std::vector<double>> &recurrent_states) {
this->recurrent_states_ = recurrent_states;
}
Expand Down
17 changes: 14 additions & 3 deletions invertedai_cpp/invertedai/drive_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <optional>
#include <string>
#include <vector>
#include <map>

#include "externals/json.hpp"

Expand All @@ -20,7 +21,8 @@ class DriveRequest {
std::string location_;
std::vector<AgentState> agent_states_;
std::vector<AgentAttributes> agent_attributes_;
std::vector<TrafficLightState> traffic_lights_states_;
std::optional<std::map<std::string, std::string>> traffic_lights_states_;
std::optional<std::vector<LightRecurrentState>> light_recurrent_states_;
std::vector<std::vector<double>> recurrent_states_;
bool get_birdview_;
bool get_infractions_;
Expand Down Expand Up @@ -73,11 +75,15 @@ class DriveRequest {
/**
* Get the states of traffic lights.
*/
std::vector<TrafficLightState> traffic_lights_states() const;
std::optional<std::map<std::string, std::string>> traffic_lights_states() const;
/**
* Get the recurrent states for all agents.
*/
std::vector<std::vector<double>> recurrent_states() const;
/**
* Get light recurrent states for all light groups in location.
*/
std::optional<std::vector<LightRecurrentState>> light_recurrent_states() const;
/**
* Check whether to return an image visualizing the simulation state.
*/
Expand Down Expand Up @@ -126,13 +132,18 @@ class DriveRequest {
* traffic light for which no state is provided will be ignored by the agents.
*/
void set_traffic_lights_states(
const std::vector<TrafficLightState> &traffic_lights_states);
const std::map<std::string, std::string> &traffic_lights_states);
/**
* Set the recurrent states for all agents, obtained from the
* previous call to drive() or initialize().
*/
void set_recurrent_states(
const std::vector<std::vector<double>> &recurrent_states);
/**
* Set light recurrent states for all light groups in location,
*/
void set_light_recurrent_states(
const std::vector<LightRecurrentState> &light_recurrent_states);
/**
* Set whether to return an image visualizing the simulation state.
* This is very slow and should only be used for debugging.
Expand Down
64 changes: 64 additions & 0 deletions invertedai_cpp/invertedai/drive_response.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,34 @@ DriveResponse::DriveResponse(const std::string &body_str) {
}
this->recurrent_states_.push_back(recurrent_state);
}
if (this->traffic_lights_states_.has_value()) {
this->traffic_lights_states_.value().clear();
} else {
this->traffic_lights_states_ = std::map<std::string, std::string>();
}
if (this->body_json_["traffic_lights_states"].is_null()) {
this->traffic_lights_states_ = std::nullopt;
} else {
for (const auto &element : this->body_json_["traffic_lights_states"].items()) {
this->traffic_lights_states_.value()[element.key()] = element.value();
}
}
if (this->light_recurrent_states_.has_value()) {
this->light_recurrent_states_.value().clear();
} else {
this->light_recurrent_states_ = std::vector<LightRecurrentState>();
}
if (this->body_json_["light_recurrent_states"].is_null()) {
this->light_recurrent_states_ = std::nullopt;
} else {
for (const auto &element : this->body_json_["light_recurrent_states"]) {
LightRecurrentState light_recurrent_state = {
element[0],
element[1]
};
this->light_recurrent_states_.value().push_back(light_recurrent_state);
}
}
this->birdview_.clear();
for (const auto &element : this->body_json_["birdview"]) {
this->birdview_.push_back(element);
Expand Down Expand Up @@ -71,6 +99,26 @@ void DriveResponse::refresh_body_json_() {
}
this->body_json_["recurrent_states"].push_back(elements);
}
this->body_json_["traffic_lights_states"].clear();
if (this->traffic_lights_states_.has_value()) {
for (const auto &pair : this->traffic_lights_states_.value()) {
this->body_json_["traffic_lights_states"][pair.first] = pair.second;
}
} else {
this->body_json_["traffic_lights_states"] = nullptr;
}
this->body_json_["light_recurrent_states"].clear();
if (this->light_recurrent_states_.has_value()) {
for (const LightRecurrentState &light_recurrent_state : this->light_recurrent_states_.value()) {
json element = {
light_recurrent_state.state,
light_recurrent_state.ticks_remaining
};
this->body_json_["light_recurrent_states"].push_back(element);
}
} else {
this->body_json_["light_recurrent_states"] = nullptr;
}
this->body_json_["birdview"].clear();
for (unsigned char element : this->birdview_) {
this->body_json_["birdview"].push_back(element);
Expand Down Expand Up @@ -105,6 +153,14 @@ std::vector<std::vector<double>> DriveResponse::recurrent_states() const {
return this->recurrent_states_;
}

std::optional<std::map<std::string, std::string>> DriveResponse::traffic_lights_states() const {
return this->traffic_lights_states_;
}

std::optional<std::vector<LightRecurrentState>> DriveResponse::light_recurrent_states() const {
return this->light_recurrent_states_;
}

std::vector<unsigned char> DriveResponse::birdview() const {
return this->birdview_;
}
Expand All @@ -129,6 +185,14 @@ void DriveResponse::set_recurrent_states(const std::vector<std::vector<double>>
this->recurrent_states_ = recurrent_states;
}

void DriveResponse::set_traffic_lights_states(const std::map<std::string, std::string> &traffic_lights_states) {
this->traffic_lights_states_ = traffic_lights_states;
}

void DriveResponse::set_light_recurrent_states(const std::vector<LightRecurrentState> &light_recurrent_states) {
this->light_recurrent_states_ = light_recurrent_states;
}

void DriveResponse::set_birdview(const std::vector<unsigned char> &birdview) {
this->birdview_ = birdview;
}
Expand Down
22 changes: 22 additions & 0 deletions invertedai_cpp/invertedai/drive_response.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
#include "data_utils.h"
#include "externals/json.hpp"

#include <map>
#include <string>
#include <vector>
#include <optional>

using json = nlohmann::json;

Expand All @@ -16,6 +18,8 @@ class DriveResponse {
std::vector<AgentState> agent_states_;
std::vector<bool> is_inside_supported_area_;
std::vector<std::vector<double>> recurrent_states_;
std::optional<std::map<std::string, std::string>> traffic_lights_states_;
std::optional<std::vector<LightRecurrentState>> light_recurrent_states_;
std::vector<unsigned char> birdview_;
std::vector<InfractionIndicator> infraction_indicators_;
std::string model_version_;
Expand Down Expand Up @@ -47,6 +51,14 @@ class DriveResponse {
* Get the recurrent states for all agents.
*/
std::vector<std::vector<double>> recurrent_states() const;
/**
* Get the states of traffic lights.
*/
std::optional<std::map<std::string, std::string>> traffic_lights_states() const;
/**
* Get light recurrent states for all light groups in location.
*/
std::optional<std::vector<LightRecurrentState>> light_recurrent_states() const;
/**
* If get_birdview was set, this contains the resulting image.
*/
Expand Down Expand Up @@ -78,6 +90,16 @@ class DriveResponse {
*/
void set_recurrent_states(
const std::vector<std::vector<double>> &recurrent_states);
/**
* Set the states of traffic lights.
*/
void set_traffic_lights_states(
const std::map<std::string, std::string> &traffic_lights_states);
/**
* Set light recurrent states for all light groups in location.
*/
void set_light_recurrent_states(
const std::vector<LightRecurrentState> &light_recurrent_states);
/**
* Set birdview.
*/
Expand Down
31 changes: 12 additions & 19 deletions invertedai_cpp/invertedai/initialize_request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,15 @@ InitializeRequest::InitializeRequest(const std::string &body_str) {
this->agent_attributes_.push_back(agent_attribute);
}
this->traffic_light_state_history_.clear();
for (const auto &elements : this->body_json_["traffic_light_state_history"]) {
std::vector<TrafficLightState> traffic_light_states;
traffic_light_states.clear();
for (const auto &element : elements) {
TrafficLightState traffic_light_state = {
element[0],
element[1]
};
traffic_light_states.push_back(traffic_light_state);
std::vector<std::map<std::string, std::string>> traffic_light_states;
for (const auto &element : this->body_json_["traffic_light_state_history"]) {
std::map<std::string, std::string> light_states;
for (const auto &pair : element.items()) {
light_states[pair.key()] = pair.value();
}
this->traffic_light_state_history_.push_back(traffic_light_states);
traffic_light_states.push_back(light_states);
}
this->traffic_light_state_history_ = traffic_light_states;
this->location_of_interest_ = this->body_json_["location_of_interest"].is_null()
? std::nullopt
: std::optional<std::pair<double, double>>{this->body_json_["location_of_interest"]};
Expand Down Expand Up @@ -82,15 +79,11 @@ void InitializeRequest::refresh_body_json_() {
this->body_json_["agent_attributes"].push_back(element);
}
this->body_json_["traffic_light_state_history"].clear();
for (const std::vector<TrafficLightState> &traffic_light_states : this->traffic_light_state_history_) {
for (const std::map<std::string, std::string> &traffic_light_states : this->traffic_light_state_history_) {
json elements;
elements.clear();
for (const TrafficLightState &traffic_light_state : traffic_light_states) {
json element = {
traffic_light_state.id,
traffic_light_state.value
};
elements.push_back(element);
for (const auto &pair : traffic_light_states) {
elements[pair.first] = pair.second;
}
this->body_json_["traffic_light_state_history"].push_back(elements);
}
Expand Down Expand Up @@ -138,7 +131,7 @@ std::vector<AgentAttributes> InitializeRequest::agent_attributes() const {
return this->agent_attributes_;
}

std::vector<std::vector<TrafficLightState>> InitializeRequest::traffic_light_state_history() const {
std::vector<std::map<std::string, std::string>> InitializeRequest::traffic_light_state_history() const {
return this->traffic_light_state_history_;
}

Expand Down Expand Up @@ -178,7 +171,7 @@ void InitializeRequest::set_agent_attributes(const std::vector<AgentAttributes>
this->agent_attributes_ = agent_attributes;
}

void InitializeRequest::set_traffic_light_state_history(const std::vector<std::vector<TrafficLightState>>&traffic_light_state_history) {
void InitializeRequest::set_traffic_light_state_history(const std::vector<std::map<std::string, std::string>>&traffic_light_state_history) {
this->traffic_light_state_history_ = traffic_light_state_history;
}

Expand Down
Loading

0 comments on commit 9ff0510

Please sign in to comment.