Skip to content
This repository has been archived by the owner on Mar 17, 2020. It is now read-only.

Commit

Permalink
Adapt the code to train on the Three map
Browse files Browse the repository at this point in the history
There were some changes to how the Region information is retrieved when generating the State representation for Protobuf.
  • Loading branch information
BlueDi committed Jun 24, 2019
1 parent da72db8 commit 497a7e5
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 40 deletions.
Binary file modified bandana/TournamentRunner.jar
Binary file not shown.
Binary file modified bandana/agents/DeepDip.jar
Binary file not shown.
Binary file modified bandana/agents/DumbBot.jar
Binary file not shown.
9 changes: 9 additions & 0 deletions bandana/src/main/java/cruz/agents/DeepDip.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public class DeepDip extends DumbBot {
// The OpenAI Adapter contains the necessary functions and fields to make the connection to the Open AI environment
private OpenAIAdapter openAIAdapter;
private Logger logger = new Logger();
private String winner = null;

private DeepDip(String name, int finalYear, String logPath) {
super(name, finalYear, logPath);
Expand Down Expand Up @@ -195,6 +196,7 @@ private List<MTOOrder> getMTOOrders(List<Order> orders) {

@Override
public void handleSlo(String winner) {
this.winner = winner;
this.openAIAdapter.setWinner(winner);
if (this.me.getName().equals(winner)) {
System.out.println("GAME RESULT: " + this.me + " won with a solo victory.");
Expand All @@ -207,10 +209,17 @@ public void handleSlo(String winner) {
@Override
public void handleSMR(String[] message) {
    // Parse the end-of-game summary message into a structured result.
    GameResult gameResult = new GameResult(message, 2);

    // Decide the outcome label to report to the OpenAI adapter before
    // notifying it that the game ended. Elimination takes precedence:
    // if we no longer control any regions we were knocked out, even if
    // no solo winner was announced.
    if (this.me.getControlledRegions().isEmpty()) {
        this.openAIAdapter.setWinner("eliminated");
    } else if (this.winner == null) {
        // handleSlo never fired, so no solo victory occurred — the game
        // ended in a draw among the surviving powers.
        this.openAIAdapter.setWinner("draw");
    }
    this.openAIAdapter.endOfGame(gameResult);

    System.out.println("END GAME: " + Arrays.toString(message));
    super.handleSMR(message);
    // Reset per-game state so a following game in the same tournament
    // does not inherit the previous winner.
    this.winner = null;
}

public Logger getLogger() {
Expand Down
50 changes: 18 additions & 32 deletions bandana/src/main/java/cruz/agents/OpenAIAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public class OpenAIAdapter {
public static final int INVALID_DEAL_REWARD = -10;

/** Reward given for winning the game. */
public static final int SC_WIN = 5;
public static final int SC_WIN = 8;

/** The OpenAINegotiator instance to which this adapter is connected. */
public OpenAINegotiator agent;
Expand Down Expand Up @@ -184,46 +184,48 @@ private void generatePowerNameToIntMap() {
}
}

/**
* First process all Provinces.
* Then add the Owners and Units of each Province.
* Then add the created Provinces to the Observation.
*/
private ProtoMessage.ObservationData generateObservationData() {
ProtoMessage.ObservationData.Builder observationDataBuilder = ProtoMessage.ObservationData.newBuilder();
Map<String, ProtoMessage.ProvinceData.Builder> nameToProvinceDataBuilder = new HashMap<>();

String agent_name = (this.agent2 == null)? this.agent.me.getName() : this.agent2.getMe().getName();
observationDataBuilder.setPlayer(powerNameToInt.get(agent_name));

// FIRST PROCESS ALL PROVINCES
Vector<Province> provinces = (this.agent2 == null) ? this.agent.game.getProvinces() : this.agent2.getGame().getProvinces();
Vector<Region> regions = (this.agent2 == null) ? this.agent.game.getRegions() : this.agent2.getGame().getRegions();
int id = 1;
for (Province p : provinces) {
for (Region r : regions) {
ProtoMessage.ProvinceData.Builder provinceDataBuilder = ProtoMessage.ProvinceData.newBuilder();
int isSc = p.isSC() ? 1 : 0;
int isSc = r.getProvince().isSC() ? 1 : 0;

provinceDataBuilder.setId(id);
provinceDataBuilder.setSc(isSc);

nameToProvinceDataBuilder.put(p.getName(), provinceDataBuilder);
nameToProvinceDataBuilder.put(r.getName(), provinceDataBuilder);

id++;
}

// THEN ADD THE OWNERS & UNITS OF EACH PROVINCE
List<Power> powers = (this.agent2 == null)? this.agent.game.getPowers():this.agent2.getGame().getPowers();
for (Power pow : powers) {
for (Province p : pow.getOwnedSCs()) {
// Get the correspondent province builder and add the current owner of the province
ProtoMessage.ProvinceData.Builder provinceDataBuilder = nameToProvinceDataBuilder.get(p.getName());
provinceDataBuilder.setOwner(powerNameToInt.get(pow.getName()));
for (Region r : p.getRegions()) {
ProtoMessage.ProvinceData.Builder provinceDataBuilder = nameToProvinceDataBuilder.get(r.getName());
provinceDataBuilder.setOwner(powerNameToInt.get(pow.getName()));
}
}

for (Region r : pow.getControlledRegions()) {
Province p = r.getProvince();
ProtoMessage.ProvinceData.Builder provinceDataBuilder = nameToProvinceDataBuilder.get(p.getName());
ProtoMessage.ProvinceData.Builder provinceDataBuilder = nameToProvinceDataBuilder.get(r.getName());
provinceDataBuilder.setOwner(powerNameToInt.get(pow.getName()));
provinceDataBuilder.setUnit(powerNameToInt.get(pow.getName()));
}
}

// ADD CREATED PROVINCES TO OBSERVATION
for (Map.Entry<String, ProtoMessage.ProvinceData.Builder> entry : nameToProvinceDataBuilder.entrySet()) {
observationDataBuilder.addProvinces(entry.getValue().build());
}
Expand All @@ -247,7 +249,7 @@ private void rewardFunction() {
String agent_name = (this.agent2 == null)? this.agent.me.getName() : this.agent2.getMe().getName();
if (agent_name.equals(this.winner)) {
this.reward += SC_WIN;
} else {
} else if (!this.winner.equals("draw")){
this.reward -= SC_WIN;
}
} else {
Expand All @@ -259,8 +261,6 @@ private BasicDeal generateDeal(ProtoMessage.DealData dealData) {
List<DMZ> dmzs = new ArrayList<>();
List<OrderCommitment> ocs = new ArrayList<>();


// Add MY order commitment
Province ourStartProvince = this.agent.game.getProvinces().get(dealData.getOurMove().getStartProvince());
Province ourDestinationProvince = this.agent.game.getProvinces().get(dealData.getOurMove().getDestinationProvince());

Expand Down Expand Up @@ -303,24 +303,14 @@ private BasicDeal generateDeal(ProtoMessage.DealData dealData) {
private List<Order> generateOrders(ProtoMessage.OrdersData ordersData) {
List<Order> orders = new ArrayList<>();
List<ProtoMessage.OrderData> support_orders = new ArrayList<>();
List<Province> game_provinces = this.agent2.getGame().getProvinces();
List<Region> game_regions = this.agent2.getGame().getRegions();

for (ProtoMessage.OrderData order : ordersData.getOrdersList()) {
if (order.getStart() == -1){
break;
}
Province start_province = game_provinces.get(order.getStart());
Province destination_province = game_provinces.get(order.getDestination());

Region start = game_regions.stream()
.filter(r -> r.getProvince().getName().equals(start_province.getName()))
.findAny()
.orElse(null);
Region destination = game_regions.stream()
.filter(r -> r.getProvince().getName().equals(destination_province.getName()))
.findAny()
.orElse(null);
Region start = game_regions.get(order.getStart());
Region destination = game_regions.get(order.getDestination());

if (order.getAction() == 0) {
orders.add(new HLDOrder(this.agent2.getMe(), start));
Expand All @@ -331,8 +321,6 @@ private List<Order> generateOrders(ProtoMessage.OrdersData ordersData) {
support_orders.add(order);
}
} else {
//System.err.println("WRONG BORDER: For order of type " + order.getAction() + ", the destination " + destination + " is not a border with current province " + start);
//this.addReward(INVALID_DEAL_REWARD);
orders.add(new HLDOrder(this.agent2.getMe(), start));
}
}
Expand All @@ -345,8 +333,6 @@ private List<Order> generateOrders(ProtoMessage.OrdersData ordersData) {
.findAny()
.orElse(null);
if (order_to_support == null) {
//System.err.println("ORDER TO SUPPORT NOT FOUND");
//this.addReward(INVALID_DEAL_REWARD);
orders.add(new HLDOrder(this.agent2.getMe(), start));
} else if (order_to_support instanceof MTOOrder) {
orders.add(new SUPMTOOrder(this.agent2.getMe(), start, (MTOOrder) order_to_support));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
public class TournamentRunner {
final static boolean MODE = false; //Strategy/false vs Negotiation/true
final static int REMOTE_DEBUG = 0; //Set whether I want to remote debug the OpenAI jar or not
private final static String GAME_MAP = "small"; //Game map can be 'standard', 'mini' or 'small'
private final static String GAME_MAP = "three"; //Game map can be 'standard', or 'mini', or 'small', or 'three'
private final static String FINAL_YEAR = "2000"; //The year after which the agents in each game are supposed to propose a draw to each other.

// Using a custom map to define how many players are there on each custom map
private final static Map<String, Integer> mapToNumberOfPlayers = new HashMap<String, Integer>() {{
put("standard", 7);
put("small", 2);
put("three", 3);
}};

// Main folder where all the logs are stored. For each tournament a new folder will be created inside this folder
Expand Down Expand Up @@ -79,13 +80,12 @@ public static void run(int numberOfGames, int moveTimeLimit, int retreatTimeLimi
//Create a list of ScoreCalculators to determine how the players should be ranked in the tournament.
ArrayList<ScoreCalculator> scoreCalculators = new ArrayList<ScoreCalculator>();

if(GAME_MAP.toLowerCase().equals("standard")) {
if (GAME_MAP.toLowerCase().equals("standard")) {
scoreCalculators.add(new SoloVictoryCalculator());
scoreCalculators.add(new SupplyCenterCalculator());
scoreCalculators.add(new PointsCalculator());
scoreCalculators.add(new RankCalculator());
}
else {
} else {
scoreCalculators.add(new RankCalculator());
}

Expand All @@ -109,7 +109,7 @@ public static void run(int numberOfGames, int moveTimeLimit, int retreatTimeLimi
name = "DeepDip";
command = deepDipCommand;
} else {
name = "DumbBot";
name = "DumbBot " + i;
command = dumbBotCommand;
}

Expand Down Expand Up @@ -159,7 +159,7 @@ public static void run(int numberOfGames, int moveTimeLimit, int retreatTimeLimi
playerProcess.destroy();
}

if(tournamentObserver != null){
if (tournamentObserver != null) {
tournamentObserver.exit();
}

Expand Down
11 changes: 9 additions & 2 deletions gym-diplomacy/gym_diplomacy/envs/diplomacy_strategy_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@

### CONSTANTS
NUMBER_OF_ACTIONS = 3
NUMBER_OF_PLAYERS = 2
NUMBER_OF_PROVINCES = 19#8#75
# Supported Diplomacy map variants, smallest to largest.
MAPS = ['mini', 'small', 'three', 'standard']
# Active map for this environment instance ('three').
CURRENT_MAP = MAPS[2]
# Number of powers (players) on each map variant.
PLAYERS = {'mini':2, 'small':2, 'three':3, 'standard':7}
NUMBER_OF_PLAYERS = PLAYERS[CURRENT_MAP]
# Number of regions on each map variant.
REGIONS = {'mini':10, 'small':19, 'three':37, 'standard':121}
# NOTE(review): despite the name, this now counts REGIONS, not provinces —
# the observation was switched to per-region granularity; confirm consumers
# of NUMBER_OF_PROVINCES expect region counts.
NUMBER_OF_PROVINCES = REGIONS[CURRENT_MAP]


def observation_data_to_observation(observation_data: proto_message_pb2.ObservationData) -> np.array:
Expand Down Expand Up @@ -58,6 +62,9 @@ def observation_data_to_observation(observation_data: proto_message_pb2.Observat
done = observation_data.done
info = {}

if len(get_player_units(observation)) < 1:
done = True

return observation, reward, done, info


Expand Down

0 comments on commit 497a7e5

Please sign in to comment.