Skip to content

Commit 0226e6e

Browse files
authored
add pip trusted hosts support (#1164)
1 parent a6d034e commit 0226e6e

File tree

2 files changed

+182
-21
lines changed

2 files changed

+182
-21
lines changed

lzy/execution-env/src/main/java/ai/lzy/env/aux/CondaPackageRegistry.java

Lines changed: 49 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ public class CondaPackageRegistry {
2727
private static final String DEFAULT_PYPI_INDEX = "https://pypi.org/simple";
2828
private static final String CONDA_YAML_FILE = "conda-desc.yaml";
2929

30+
private static final String PIP_INDEX_URL_FLAG = "--index-url";
31+
private static final String PIP_EXTRA_INDEX_URL_FLAG = "--extra-index-url";
32+
private static final String PIP_TRUSTED_HOST_FLAG = "--trusted-host";
33+
private static final String PIP_NO_DEPS_FLAG = "--no-deps";
34+
3035
// TODO(artolord) remove this ugly hack after removing conda.yaml
3136
private static final Map<String, String> NAME_TO_PYTHON_VERSION = Map.of(
3237
"py37", "3.7",
@@ -95,7 +100,8 @@ private record CondaEnv(
95100
String pythonVersion,
96101
String pypiIndex,
97102
boolean noDeps,
98-
List<String> extraIndexUrls
103+
List<String> extraIndexUrls,
104+
List<String> trustedHosts
99105
) {}
100106

101107
/**
@@ -111,12 +117,13 @@ public String buildCondaYaml(String condaYaml) {
111117
throw new IllegalArgumentException("Cannot build env from yaml");
112118
}
113119

114-
return buildCondaYaml(env.packages, env.pythonVersion, env.pypiIndex, env.noDeps, env.extraIndexUrls);
120+
return buildCondaYaml(env.packages, env.pythonVersion, env.pypiIndex, env.noDeps, env.extraIndexUrls,
121+
env.trustedHosts);
115122
}
116123

117124
@Nullable
118125
private String buildCondaYaml(Map<String, Package> packages, String pythonVersion, String pypiIndex,
119-
boolean noDeps, List<String> extraIndexUrls)
126+
boolean noDeps, List<String> extraIndexUrls, List<String> trustedHosts)
120127
{
121128
try {
122129
var installedEnv = envs.values().stream()
@@ -150,15 +157,15 @@ private String buildCondaYaml(Map<String, Package> packages, String pythonVersio
150157
}
151158

152159
return buildYaml(new CondaEnv(installedEnv.name, packages, installedEnv.pythonVersion,
153-
pypiIndex, noDeps, extraIndexUrls));
160+
pypiIndex, noDeps, extraIndexUrls, trustedHosts));
154161
}
155162

156163
} catch (Exception e) {
157164
LOG.error("Error while building conda yaml for packages {}: ", packages, e);
158165
}
159166

160167
return buildYaml(new CondaEnv("py" + pythonVersion, packages, pythonVersion, pypiIndex, noDeps,
161-
extraIndexUrls));
168+
extraIndexUrls, trustedHosts));
162169
}
163170

164171
public void notifyInstalled(String condaYaml) {
@@ -217,6 +224,7 @@ CondaEnv build(String condaYaml) {
217224
String pypiIndex = null;
218225
boolean noDeps = false;
219226
List<String> extraIndexUrls = new ArrayList<>();
227+
List<String> trustedHosts = new ArrayList<>();
220228

221229
for (var dep : deps) {
222230
if (dep instanceof String) {
@@ -239,27 +247,34 @@ CondaEnv build(String condaYaml) {
239247
}
240248

241249
//noinspection unchecked
242-
for (var pipDep : (List<Object>) pipDeps) {
243-
if (!(pipDep instanceof String)) {
250+
for (var rawPipDep : (List<Object>) pipDeps) {
251+
if (!(rawPipDep instanceof String pipDep)) {
244252
return null;
245253
}
246254

247-
if (((String) pipDep).startsWith("--index-url")) {
248-
var parts = ((String) pipDep).split(" ");
249-
pypiIndex = parts.length > 1 ? parts[1] : null;
255+
if (pipDep.startsWith(PIP_INDEX_URL_FLAG)) {
256+
pypiIndex = parsePipOptionValue(PIP_INDEX_URL_FLAG, pipDep);
257+
continue;
258+
}
259+
if (pipDep.startsWith(PIP_EXTRA_INDEX_URL_FLAG)) {
260+
final var extraIndex = parsePipOptionValue(PIP_EXTRA_INDEX_URL_FLAG, pipDep);
261+
if (extraIndex != null) {
262+
extraIndexUrls.add(extraIndex);
263+
}
250264
continue;
251265
}
252-
if (((String) pipDep).startsWith("--extra-index-url")) {
253-
var parts = ((String) pipDep).split(" ");
254-
var extraIndex = parts.length > 1 ? parts[1] : null;
255-
extraIndexUrls.add(extraIndex);
266+
if (pipDep.startsWith(PIP_TRUSTED_HOST_FLAG)) {
267+
final var trustedHost = parsePipOptionValue(PIP_TRUSTED_HOST_FLAG, pipDep);
268+
if (trustedHost != null) {
269+
trustedHosts.add(trustedHost);
270+
}
256271
continue;
257272
}
258-
if (((String) pipDep).startsWith("--no-deps")) {
273+
if (pipDep.startsWith(PIP_NO_DEPS_FLAG)) {
259274
noDeps = true;
260275
}
261276

262-
var dat = ((String) pipDep).split(VERSION_REGEX, SPLIT_LIMIT);
277+
var dat = pipDep.split(VERSION_REGEX, SPLIT_LIMIT);
263278

264279
var pkgName = normalizePkgName(dat[0]);
265280
if (dat.length == 1) {
@@ -291,19 +306,33 @@ CondaEnv build(String condaYaml) {
291306
}
292307

293308
return new CondaEnv(name, pkgs, pythonVersion, pypiIndex == null ? DEFAULT_PYPI_INDEX : pypiIndex, noDeps,
294-
extraIndexUrls);
309+
extraIndexUrls, trustedHosts);
310+
}
311+
312+
@Nullable
313+
private String parsePipOptionValue(String optionName, String optionString) {
314+
String[] optionParts = optionString.split("\\s+", 2);
315+
if (optionParts.length == 1 || !optionParts[0].equals(optionName)) {
316+
LOG.warn("Unable to parse value for option '{}' from '{}'", optionName, optionString);
317+
return null;
318+
}
319+
return optionParts[1];
295320
}
296321

297322
private String buildYaml(CondaEnv env) {
298323
var pkgs = new ArrayList<>();
299-
pkgs.add("--index-url " + env.pypiIndex);
324+
pkgs.add(PIP_INDEX_URL_FLAG + " " + env.pypiIndex);
300325

301326
if (env.noDeps) {
302-
pkgs.add("--no-deps");
327+
pkgs.add(PIP_NO_DEPS_FLAG);
303328
}
304329

305330
for (var extraUrl: env.extraIndexUrls) {
306-
pkgs.add("--extra-index-url " + extraUrl);
331+
pkgs.add(PIP_EXTRA_INDEX_URL_FLAG + " " + extraUrl);
332+
}
333+
334+
for (var trustedHost: env.trustedHosts) {
335+
pkgs.add(PIP_TRUSTED_HOST_FLAG + " " + trustedHost);
307336
}
308337

309338
for (var p: env.packages.values()) {

lzy/execution-env/src/test/java/ai/lzy/env/CondaPackageRegistryTest.java

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33

44
import ai.lzy.env.aux.CondaPackageRegistry;
5-
import ai.lzy.env.aux.SimpleBashEnvironment;
65
import ai.lzy.env.base.ProcessEnvironment;
76
import org.junit.Assert;
87
import org.junit.Before;
98
import org.junit.Test;
109

10+
import java.util.regex.Pattern;
11+
1112
public class CondaPackageRegistryTest {
1213

1314
private final CondaPackageRegistry condaPackageRegistry = new CondaPackageRegistry(new ProcessEnvironment());
@@ -244,6 +245,137 @@ public void testPipDependenciesNewDepWithoutVersion() {
244245
- serialzy>=1.0.0"""));
245246
}
246247

248+
@Test
249+
public void testPipWithIndexUrl() {
250+
String condaYaml = condaPackageRegistry.buildCondaYaml("""
251+
name: default
252+
dependencies:
253+
- python=3.9.15
254+
- pip
255+
- cloudpickle=1.0.0
256+
- pip:
257+
- numpy
258+
- --index-url https://pypi.ngc.nvidia.com
259+
- scipy
260+
- pylzy==1.0.0
261+
- serialzy>=1.0.0""");
262+
Assert.assertNotNull(condaYaml);
263+
Assert.assertTrue(
264+
"Index url is not in conda yaml",
265+
condaYaml.contains("--index-url https://pypi.ngc.nvidia.com"));
266+
}
267+
268+
@Test
269+
public void testPipWithoutExplicitIndexUrl() {
270+
String condaYaml = condaPackageRegistry.buildCondaYaml("""
271+
name: default
272+
dependencies:
273+
- python=3.9.15
274+
- pip
275+
- cloudpickle=1.0.0
276+
- pip:
277+
- numpy
278+
- scipy
279+
- pylzy==1.0.0
280+
- serialzy>=1.0.0""");
281+
Assert.assertNotNull(condaYaml);
282+
Assert.assertTrue(
283+
"Index url is not in conda yaml",
284+
condaYaml.contains("--index-url https://pypi.org/simple"));
285+
}
286+
287+
@Test
288+
public void testPipWithExtraIndexUrls() {
289+
String condaYaml = condaPackageRegistry.buildCondaYaml("""
290+
name: default
291+
dependencies:
292+
- python=3.9.15
293+
- pip
294+
- cloudpickle=1.0.0
295+
- pip:
296+
- numpy
297+
- --extra-index-url https://pypi.ngc.nvidia.com
298+
- --extra-index-url https://pypy.example.com
299+
- --extra-index-urlhttps://pypy.invalid-example.com
300+
- scipy
301+
- pylzy==1.0.0
302+
- serialzy>=1.0.0""");
303+
Assert.assertNotNull(condaYaml);
304+
Assert.assertTrue(
305+
"Extra index url is not in final conda yaml",
306+
condaYaml.contains("--extra-index-url https://pypi.ngc.nvidia.com"));
307+
Assert.assertTrue(
308+
"Extra index url is not in final conda yaml",
309+
condaYaml.contains("--extra-index-url https://pypy.example.com"));
310+
Assert.assertEquals(
311+
"Invalid extra index url is in final conda yaml",
312+
2,
313+
Pattern.compile("--extra-index-url").matcher(condaYaml).results().count());
314+
}
315+
316+
@Test
317+
public void testPipWithTrustedHosts() {
318+
String condaYaml = condaPackageRegistry.buildCondaYaml("""
319+
name: default
320+
dependencies:
321+
- python=3.9.15
322+
- pip
323+
- cloudpickle=1.0.0
324+
- pip:
325+
- numpy
326+
- --trusted-host pypi.ngc.nvidia.com
327+
- --trusted-host example.com:1234
328+
- --trusted-hostinvalid-example.com
329+
- scipy
330+
- pylzy==1.0.0
331+
- serialzy>=1.0.0""");
332+
Assert.assertNotNull(condaYaml);
333+
Assert.assertTrue(
334+
"Trusted host is not in final conda yaml",
335+
condaYaml.contains("--trusted-host pypi.ngc.nvidia.com"));
336+
Assert.assertTrue(
337+
"Trusted host is not in final conda yaml",
338+
condaYaml.contains("--trusted-host example.com:1234"));
339+
Assert.assertEquals(
340+
"Invalid trusted host is in final conda yaml",
341+
2,
342+
Pattern.compile("--trusted-host").matcher(condaYaml).results().count());
343+
}
344+
@Test
345+
public void testPipWithNoDeps() {
346+
String condaYaml = condaPackageRegistry.buildCondaYaml("""
347+
name: default
348+
dependencies:
349+
- python=3.9.15
350+
- pip
351+
- cloudpickle=1.0.0
352+
- pip:
353+
- --no-deps
354+
- numpy
355+
- scipy
356+
- pylzy==1.0.0
357+
- serialzy>=1.0.0""");
358+
Assert.assertNotNull(condaYaml);
359+
Assert.assertTrue("--no-deps is not in conda yaml", condaYaml.contains("--no-deps"));
360+
}
361+
362+
@Test
363+
public void testPipWithoutNoDeps() {
364+
String condaYaml = condaPackageRegistry.buildCondaYaml("""
365+
name: default
366+
dependencies:
367+
- python=3.9.15
368+
- pip
369+
- cloudpickle=1.0.0
370+
- pip:
371+
- numpy
372+
- scipy
373+
- pylzy==1.0.0
374+
- serialzy>=1.0.0""");
375+
Assert.assertNotNull(condaYaml);
376+
Assert.assertFalse("--no-deps is in conda yaml", condaYaml.contains("--no-deps"));
377+
}
378+
247379
@Test
248380
public void testInvalidYaml() {
249381
Assert.assertThrows(Exception.class, () -> condaPackageRegistry.buildCondaYaml("///////"));

0 commit comments

Comments
 (0)