Skip to content
This repository has been archived by the owner on Mar 11, 2021. It is now read-only.

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
tommadams committed Jan 11, 2020
1 parent 50f6852 commit 64d5410
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 24 deletions.
11 changes: 9 additions & 2 deletions cc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,11 @@ minigo_cc_library(
minigo_cc_library(
name = "logging",
srcs = ["logging.cc"],
hdrs = ["logging.h"],
linkopts = select({
"@bazel_tools//src/conditions:windows": ["ws2_32.lib"],
"//conditions:default": ["-lpthread"],
}),
hdrs = ["logging.h"],
deps = [
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/debugging:stacktrace",
Expand All @@ -148,11 +148,13 @@ minigo_cc_library(
"//cc/async:thread",
"//cc/async:thread_safe_queue",
"//cc/file",
"//cc/model",
"//cc/model:batching_model",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
Expand All @@ -161,10 +163,11 @@ minigo_cc_library(
# file named `downloaded`. Rename it to `json.h`.
genrule(
name = "json_h",
outs = ["json.h"],
srcs = ["@com_github_nlohmann_json_single_header//file"],
outs = ["json.h"],
cmd = "cp $< $@",
)

cc_library(
name = "json",
hdrs = [":json_h"],
Expand Down Expand Up @@ -358,6 +361,7 @@ minigo_cc_test_9_only(
size = "small",
srcs = ["mcts_tree_test.cc"],
deps = [
":base",
":mcts",
":position",
":random",
Expand Down Expand Up @@ -404,9 +408,12 @@ minigo_cc_test(
srcs = ["pass_alive_test.cc"],
deps = [
":base",
":logging",
":position",
":random",
":test_utils",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_googletest//:gtest_main",
],
)
Expand Down
5 changes: 1 addition & 4 deletions cc/dual_net/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@ licenses(["notice"]) # Apache License 2.0

load(
"//cc/config:minigo.bzl",
"minigo_cc_binary",
"minigo_cc_library",
"minigo_cc_test",
"minigo_cc_test_19_only",
"minigo_cc_test_9_only",
"minigo_engine_copts",
)
Expand Down Expand Up @@ -72,8 +69,8 @@ minigo_cc_library(
deps = [
"//cc:base",
"//cc:logging",
"//cc/file:path",
"//cc/file",
"//cc/file:path",
"//cc/model",
"//cc/model:factory",
"//cc/platform",
Expand Down
2 changes: 1 addition & 1 deletion cc/dual_net/tpu_dual_net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ TpuDualNetFactory::LoadedModel TpuDualNetFactory::GetModel(
coded_stream.ConsumedEntireMessage());

// Find the data type of the input features.
tensorflow::DataType dt;
tensorflow::DataType dt = tensorflow::DT_INVALID;
const auto& input_type = def.metadata.Get<std::string>("input_type");
if (input_type == "bool") {
dt = tensorflow::DT_BOOL;
Expand Down
18 changes: 9 additions & 9 deletions cc/minigui_gtp_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ MiniguiGtpClient::MiniguiGtpClient(
const std::string& model_path, const Game::Options& game_options,
const MctsPlayer::Options& player_options,
const GtpClient::Options& client_options)
: GtpClient(std::move(device), inference_cache, model_path,
game_options, player_options, client_options) {
: GtpClient(std::move(device), inference_cache, model_path, game_options,
player_options, client_options) {
RegisterCmd("echo", &MiniguiGtpClient::HandleEcho);
RegisterCmd("genmove", &MiniguiGtpClient::HandleGenmove);
RegisterCmd("play", &MiniguiGtpClient::HandlePlay);
Expand All @@ -59,8 +59,8 @@ MiniguiGtpClient::MiniguiGtpClient(
auto worker_options = player_options;
worker_options.virtual_losses = 1;
win_rate_evaluator_ = absl::make_unique<WinRateEvaluator>(
num_workers, num_win_rate_evals, device_, inference_cache,
model_path, game_options, worker_options);
num_workers, num_win_rate_evals, device_, inference_cache, model_path,
game_options, worker_options);
}

MiniguiGtpClient::~MiniguiGtpClient() = default;
Expand Down Expand Up @@ -455,9 +455,9 @@ MiniguiGtpClient::WinRateEvaluator::WinRateEvaluator(
batcher_ = absl::make_unique<BatchingModelFactory>(device, 2);
for (int i = 0; i < num_workers; ++i) {
auto game = absl::make_unique<Game>("b", "w", game_options);
auto player = absl::make_unique<MctsPlayer>(
batcher_->NewModel(model_path), inference_cache, game.get(),
player_options);
auto player = absl::make_unique<MctsPlayer>(batcher_->NewModel(model_path),
inference_cache, game.get(),
player_options);
workers_.push_back(absl::make_unique<Worker>(
std::move(game), std::move(player), &eval_queue_));
workers_.back()->Start();
Expand Down Expand Up @@ -539,6 +539,7 @@ MiniguiGtpClient::WinRateEvaluator::Worker::~Worker() {
}

void MiniguiGtpClient::WinRateEvaluator::Worker::Prepare() {
absl::MutexLock lock(&mutex_);
BatchingModelFactory::StartGame(player_->model(), player_->model());
}

Expand All @@ -552,8 +553,7 @@ void MiniguiGtpClient::WinRateEvaluator::Worker::EvalAsync(
void MiniguiGtpClient::WinRateEvaluator::Worker::Run() {
for (;;) {
absl::MutexLock lock(&mutex_);
mutex_.Await(absl::Condition(
&pending_, &absl::optional<VariationTree::Node*>::has_value));
mutex_.Await(absl::Condition(this, &Worker::has_pending_value));
auto* node = *pending_;
pending_.reset();
if (node == nullptr) {
Expand Down
4 changes: 4 additions & 0 deletions cc/minigui_gtp_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ class MiniguiGtpClient : public GtpClient {
private:
void Run() override;

bool has_pending_value() const EXCLUSIVE_LOCKS_REQUIRED(&mutex_) {
return pending_.has_value();
}

absl::Mutex mutex_;
absl::optional<VariationTree::Node*> pending_ GUARDED_BY(&mutex_);
std::unique_ptr<Game> game_;
Expand Down
3 changes: 3 additions & 0 deletions cc/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ minigo_cc_library(
hdrs = ["factory.h"],
deps = [
":model",
"//cc:logging",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:variant",
Expand Down Expand Up @@ -162,6 +163,8 @@ minigo_cc_binary(
":model",
"//cc:base",
"//cc:init",
"//cc:logging",
"//cc:position",
"//cc:random",
"//cc:symmetries",
"@com_google_absl//absl/strings",
Expand Down
10 changes: 5 additions & 5 deletions cc/model/factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ class LogModelProperty {
std::ostream* os_;
};

std::ostream& operator<<(std::ostream& os, const ModelProperty& p) {
absl::visit(LogModelProperty(&os), p);
return os;
}

} // namespace

std::string ModelMetadata::DebugString() const {
Expand All @@ -50,11 +55,6 @@ std::string ModelMetadata::DebugString() const {
return absl::StrCat("{", absl::StrJoin(items, ", "), "}");
}

std::ostream& operator<<(std::ostream& os, const ModelProperty& p) {
absl::visit(LogModelProperty(&os), p);
return os;
}

ModelFactory::~ModelFactory() = default;

} // namespace minigo
5 changes: 2 additions & 3 deletions cc/model/factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,14 @@
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/types/variant.h"
#include "cc/logging.h"
#include "cc/model/model.h"

namespace minigo {

using ModelProperty =
absl::variant<std::string, bool, int64_t, uint64_t, float>;

std::ostream& operator<<(std::ostream& os, const ModelProperty& p);

// Although the metadata is stored in the Minigo file as JSON, it is
// converted on load to a simpler representation to avoid pulling an entire
// JSON library into this header (at the time of writing the nlohmann::json
Expand All @@ -57,7 +56,7 @@ class ModelMetadata {
template <typename T>
const T& Get(absl::string_view key) const {
const auto& prop = impl_.at(key);
MG_DCHECK(absl::holds_alternative<T>(prop)) << prop;
MG_DCHECK(absl::holds_alternative<T>(prop)) << DebugString();
return absl::get<T>(prop);
}

Expand Down

0 comments on commit 64d5410

Please sign in to comment.