Skip to content

Commit

Permalink
chech docker image platform when pulling docker image (#1166)
Browse files Browse the repository at this point in the history
* chech docker image platform when pulling docker image

* chech docker image platform when pulling docker image * some fixes

* check docker image platform when pulling docker image * added case for docker client exception

* check docker image platform when pulling docker image * code review

---------

Co-authored-by: Andrey Balakshiy <andrey9594@yandex-team.ru>
Co-authored-by: Balakshiy Andrey <abalakshiy@marathonbet.ru>
  • Loading branch information
3 people authored Mar 18, 2024
1 parent 50e1686 commit d6376d1
Show file tree
Hide file tree
Showing 5 changed files with 299 additions and 16 deletions.
10 changes: 10 additions & 0 deletions lzy/execution-env/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.commons.lang3.StringUtils;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.stream.Collectors;

public record DockerEnvDescription(
Expand All @@ -18,7 +16,8 @@ public record DockerEnvDescription(
List<String> envVars, // In format <NAME>=<value>
@Nullable
String networkMode,
DockerClientConfig dockerClientConfig
DockerClientConfig dockerClientConfig,
Set<String> allowedPlatforms // In format os/arch like "linux/amd64". Empty means all are allowed
) {

public static Builder newBuilder() {
Expand All @@ -32,6 +31,7 @@ public String toString() {
", image='" + image + '\'' +
", needGpu=" + needGpu +
", networkMode=" + networkMode +
", allowedPlatforms=" + String.join(", ", allowedPlatforms) +
", mounts=[" + mounts.stream()
.map(it -> it.source() + " -> " + it.target() + (it.isRshared() ? " (R_SHARED)" : ""))
.collect(Collectors.joining(", ")) + "]" +
Expand All @@ -52,6 +52,7 @@ public static class Builder {
List<String> envVars = new ArrayList<>();
String networkMode = null;
DockerClientConfig dockerClientConfig;
Set<String> allowedPlatforms = new HashSet<>();

public Builder withName(String name) {
this.name = name;
Expand Down Expand Up @@ -93,13 +94,18 @@ public Builder withDockerClientConfig(DockerClientConfig dockerClientConfig) {
return this;
}

public Builder withAllowedPlatforms(Collection<String> allowedPlatforms) {
this.allowedPlatforms.addAll(allowedPlatforms);
return this;
}

public DockerEnvDescription build() {
if (StringUtils.isBlank(name)) {
name = "job-" + RandomStringUtils.randomAlphanumeric(5);
}
return new DockerEnvDescription(name, image, mounts, gpu, envVars, networkMode, dockerClientConfig);
return new DockerEnvDescription(name, image, mounts, gpu, envVars, networkMode, dockerClientConfig,
allowedPlatforms);
}

}

public record ContainerRegistryCredentials(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@
import ai.lzy.env.EnvironmentInstallationException;
import ai.lzy.env.logs.LogStream;
import com.github.dockerjava.api.DockerClient;
import com.github.dockerjava.api.async.ResultCallback;
import com.github.dockerjava.api.async.ResultCallbackTemplate;
import com.github.dockerjava.api.command.ExecCreateCmd;
import com.github.dockerjava.api.command.ExecCreateCmdResponse;
import com.github.dockerjava.api.command.InspectImageResponse;
import com.github.dockerjava.api.command.PullImageResultCallback;
import com.github.dockerjava.api.exception.DockerClientException;
import com.github.dockerjava.api.exception.DockerException;
import com.github.dockerjava.api.exception.NotFoundException;
import com.github.dockerjava.api.model.*;
import com.github.dockerjava.core.DockerClientImpl;
import com.github.dockerjava.httpclient5.ApacheDockerHttpClient;
import com.google.common.annotations.VisibleForTesting;
import io.github.resilience4j.core.IntervalFunction;
import io.github.resilience4j.retry.Retry;
import io.github.resilience4j.retry.RetryConfig;
Expand All @@ -28,6 +31,7 @@
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
Expand All @@ -37,6 +41,8 @@ public class DockerEnvironment extends BaseEnvironment {
private static final Logger LOG = LogManager.getLogger(DockerEnvironment.class);
private static final long GB_AS_BYTES = 1073741824;
private static final String ROOT_USER_UID = "0";
private static final String NO_MATCHING_MANIFEST_ERROR = "no matching manifest";
private static final String NOT_MATCH_PLATFORM_ERROR = "was found but does not match the specified platform";

@Nullable
public String containerId = null;
Expand Down Expand Up @@ -275,12 +281,14 @@ public void close() throws Exception {
}
}

private void prepareImage(String image, LogStream out) throws Exception {
@VisibleForTesting
void prepareImage(String image, LogStream out) throws Exception {
try {
client.inspectImageCmd(image).exec();
var inspectImageResponse = client.inspectImageCmd(image).exec();
var msg = "Image %s exists".formatted(image);
LOG.info(msg);
out.log(msg);
checkPlatform(inspectImageResponse, out);
return;
} catch (NotFoundException ignored) {
var msg = "Image %s not found in cached images".formatted(image);
Expand All @@ -291,16 +299,67 @@ private void prepareImage(String image, LogStream out) throws Exception {
var msg = "Pulling image %s ...".formatted(image);
LOG.info(msg);
out.log(msg);
Set<String> allowedPlatforms = config.allowedPlatforms();
AtomicInteger pullingAttempt = new AtomicInteger(0);
retry.executeCallable(() -> {
try (var pullResponseItem = retry.executeCallable(() -> {
LOG.info("Pulling image {}, attempt {}", image, pullingAttempt.incrementAndGet());
final var pullingImage = client
.pullImageCmd(config.image())
.exec(new PullImageResultCallback());
return pullingImage.awaitCompletion();
});
if (allowedPlatforms.isEmpty()) {
return pullWithPlatform(image, null);
} else {
for (String platform : config.allowedPlatforms()) {
try {
return pullWithPlatform(image, platform);
} catch (DockerClientException e) {
String exceptionMessage = e.getMessage();
if (exceptionMessage.contains(NO_MATCHING_MANIFEST_ERROR) ||
exceptionMessage.contains(NOT_MATCH_PLATFORM_ERROR)) {
LOG.info("Cannot find image = {} for platform = {}: message = {}",
image, platform, exceptionMessage);
} else {
throw e;
}
}
}
}
return null;
}))
{
if (pullResponseItem == null) {
throw new RuntimeException("Cannot pull image for allowed platforms = %s".formatted(
String.join(", ", allowedPlatforms)));
}
}

msg = "Pulling image %s done".formatted(image);
LOG.info(msg);
out.log(msg);
}

private ResultCallback.Adapter<PullResponseItem> pullWithPlatform(String image, @Nullable String platform)
throws InterruptedException
{
var pullingImage = client.pullImageCmd(image);
if (platform != null) {
pullingImage = pullingImage.withPlatform(platform);
}
return pullingImage.exec(new PullImageResultCallback()).awaitCompletion();
}

private void checkPlatform(InspectImageResponse inspectImageResponse, LogStream out) {
Set<String> allowedPlatforms = config.allowedPlatforms();
if (allowedPlatforms.isEmpty()) {
return;
}

String platform = inspectImageResponse.getOs() + "/" + inspectImageResponse.getArch();
if (!allowedPlatforms.contains(platform)) {
var allowedPlatformsStr = String.join(", ", allowedPlatforms);
var msg = "Image %s with platform = %s is not in allowed platforms = %s".formatted(
config.image(), platform, allowedPlatformsStr);
LOG.info(msg);
out.log(msg);

throw new RuntimeException(msg);
}
}
}
Loading

0 comments on commit d6376d1

Please sign in to comment.