-
Notifications
You must be signed in to change notification settings - Fork 31
Expand file tree
/
Copy pathtraining_information.ts
More file actions
123 lines (116 loc) · 4.29 KB
/
training_information.ts
File metadata and controls
123 lines (116 loc) · 4.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import { z } from "zod";
import type { DataType, Network } from "../index.js";
import { Tokenizer } from "../index.js";
// Differential-privacy parameters. Turning them on reduces training accuracy
// in exchange for improved privacy.
const differentialPrivacySchema = z.object({
  // maximum weights difference between each epoch; positive, 1 when omitted
  clippingRadius: z.number().positive().default(1),
  // privacy budget, used to compute the variance of the Gaussian noise; positive, required
  epsilon: z.number().positive(),
  // small probability that the privacy guarantee may not hold; strictly between 0 and 1
  delta: z.number().gt(0).lt(1),
});

// Privacy settings of a task. Differential privacy is opt-in: the key may be
// left out entirely.
const privacySchema = z.object({
  differentialPrivacy: differentialPrivacySchema.optional(),
});
// `privacy` field for the strategies that add no privacy settings of their own:
// the field may be omitted, and a parsed object whose `differentialPrivacy` is
// undefined is collapsed to `undefined`.
const collapsiblePrivacySchema = privacySchema
  .transform((privacy) =>
    privacy.differentialPrivacy !== undefined ? privacy : undefined,
  )
  .optional();

// Settings specific to the byzantine aggregation strategy.
const byzantineFaultToleranceSchema = z.object({
  // maximum weights difference between each round; positive, required
  clippingRadius: z.number().positive(),
  // positive integer, 1 when omitted
  maxIterations: z.number().int().positive().default(1),
  // number within [0, 1], 0.9 when omitted
  beta: z.number().min(0).max(1).default(0.9),
});

const meanAggregationSchema = z.object({
  aggregationStrategy: z.literal("mean"),
  privacy: collapsiblePrivacySchema,
});

const byzantineAggregationSchema = z.object({
  aggregationStrategy: z.literal("byzantine"),
  // here `privacy` is required: it carries the fields of `privacySchema`
  // plus the mandatory `byzantineFaultTolerance` settings
  privacy: z.object({
    ...privacySchema.shape,
    byzantineFaultTolerance: byzantineFaultToleranceSchema,
  }),
});

const secureAggregationSchema = z.object({
  aggregationStrategy: z.literal("secure"),
  privacy: collapsiblePrivacySchema,
  // Secure Aggregation: maximum absolute value of a number in a randomly generated share.
  // Positive integer, 100 when omitted; see docs/PRIVACY.md for guidance on choosing it.
  maxShareValue: z.number().positive().int().optional().default(100),
});

// Settings shared by the networks in which participants train collaboratively:
// a participant threshold combined with exactly one aggregation-strategy variant.
const nonLocalNetworkSchema = z
  .object({
    // minimum number of participants required to train collaboratively
    // NOTE(review): earlier comments mention a default of 3 for decentralized and
    // 2 for federated learning; this schema sets no default, so presumably those
    // are applied elsewhere -- confirm.
    minNbOfParticipants: z.number().positive().int(),
  })
  .and(
    z.union([
      meanAggregationSchema,
      byzantineAggregationSchema,
      secureAggregationSchema,
    ]),
  );
export namespace TrainingInformation {
  // Returns a fresh schema accepting strictly positive integers.
  const positiveInt = () => z.number().positive().int();

  // Settings common to every task, whatever its data type or network.
  export const baseSchema = z.object({
    // number of epochs to run training for
    epochs: positiveInt(),
    // number of epochs between each weight sharing round,
    // e.g. if 3 then weights are shared every 3 epochs (in the distributed setting)
    roundDuration: positiveInt(),
    // fraction of data to keep for validation, within [0, 1];
    // note this only works for image data
    validationSplit: z.number().min(0).max(1),
    // batch size of training data
    batchSize: positiveInt(),
    // tensor framework used by the model
    tensorBackend: z.enum(["gpt", "tfjs"]),
  });

  const imageSchema = z.object({
    // classes, e.g. for two classes of images, one with dogs and one with cats,
    // this would be ['dogs', 'cats']; at least one label is required
    LABEL_LIST: z.array(z.string()).min(1),
    // width of image to resize to
    IMAGE_W: positiveInt(),
    // height of image to resize to
    IMAGE_H: positiveInt(),
  });

  const tabularSchema = z.object({
    // the columns to be chosen as input data for the model
    inputColumns: z.array(z.string()),
    // the column to be predicted by the model
    outputColumn: z.string(),
  });

  const textSchema = z.object({
    // a Tokenizer instance (checked with instanceof), not the name of a tokenizer
    tokenizer: z.instanceof(Tokenizer),
    // the maximum length of an input string used as input to a GPT model; used during
    // preprocessing to truncate strings to a maximum length.
    // NOTE(review): earlier comments say this defaults to tokenizer.model_max_length;
    // this schema sets no default, so presumably it is applied elsewhere -- confirm.
    contextLength: positiveInt(),
  });

  // Extra settings required by each data type.
  export const dataTypeToSchema = {
    image: imageSchema,
    tabular: tabularSchema,
    text: textSchema,
  } satisfies Record<DataType, unknown>;

  // Decentralized training accepts the "byzantine" and "mean" strategies.
  const decentralizedSchema = z
    .object({
      scheme: z.literal("decentralized"),
      aggregationStrategy: z.literal(["byzantine", "mean"]),
    })
    .and(nonLocalNetworkSchema);

  // Federated training additionally accepts the "secure" strategy.
  const federatedSchema = z
    .object({
      scheme: z.literal("federated"),
      aggregationStrategy: z.literal(["byzantine", "mean", "secure"]),
    })
    .and(nonLocalNetworkSchema);

  // Local training only accepts "mean" and takes none of the collaborative settings.
  const localSchema = z.object({
    scheme: z.literal("local"),
    aggregationStrategy: z.literal("mean"),
  });

  // Extra settings required by each network scheme.
  export const networkToSchema = {
    decentralized: decentralizedSchema,
    federated: federatedSchema,
    local: localSchema,
  } satisfies Record<Network, unknown>;
}
// Parsed shape of the settings common to every task.
type BaseInformation = z.infer<typeof TrainingInformation.baseSchema>;

// Parsed shape of the settings specific to data type `D`.
type DataTypeInformation<D extends DataType> = z.infer<
  (typeof TrainingInformation.dataTypeToSchema)[D]
>;

// Parsed shape of the settings specific to network `N`.
type NetworkInformation<N extends Network> = z.infer<
  (typeof TrainingInformation.networkToSchema)[N]
>;

/**
 * Complete training settings of a task handling data of type `D` on network `N`:
 * the intersection of the base settings, the data-type settings and the
 * network settings, each inferred from the matching schema in the
 * `TrainingInformation` namespace.
 */
export type TrainingInformation<
  D extends DataType,
  N extends Network,
> = BaseInformation & DataTypeInformation<D> & NetworkInformation<N>;