Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
olegranmo committed Oct 11, 2024
1 parent b7b89c3 commit d6c7a9d
Showing 1 changed file with 34 additions and 5 deletions.
39 changes: 34 additions & 5 deletions examples/classification/SequenceCountInterpretabilityDemo.py
Original file line number Diff line number Diff line change
@@ -25,7 +25,7 @@ def main(args):
experiment_results = metrics(args)

X_train = np.zeros((args.examples, 1, args.sequence_length, 2), dtype=np.uint32)
Y_train = np.random.randint(0, 3, size=(args.examples), dtype=np.uint32)
Y_train = np.random.randint(0, 4, size=(args.examples), dtype=np.uint32)
for i in range(args.examples):
position_1 = np.random.randint(0, args.sequence_length-4)
position_2 = position_1+1
@@ -67,7 +67,7 @@ def main(args):

X_train[i,0,position_5,0] = 0
X_train[i,0,position_5,1] = 0
else:
elif Y_train[i] == 2:
X_train[i,0,position_1,0] = 1
X_train[i,0,position_1,1] = 0

@@ -82,13 +82,27 @@ def main(args):

X_train[i,0,position_5,0] = 0
X_train[i,0,position_5,1] = 0
else:
X_train[i,0,position_1,0] = 1
X_train[i,0,position_1,1] = 0

X_train[i,0,position_2,0] = 0
X_train[i,0,position_2,1] = 0

X_train[i,0,position_3,0] = 0
X_train[i,0,position_3,1] = 0

X_train[i,0,position_4,0] = 0
X_train[i,0,position_4,1] = 0

X_train[i,0,position_5,0] = 0
X_train[i,0,position_5,1] = 0

if np.random.rand() <= args.noise:
Y_train[i] = np.random.choice(np.setdiff1d([0,1,2], [Y_train[i]]))
Y_train[i] = np.random.choice(np.setdiff1d([0,1,2,3], [Y_train[i]]))

X_test = np.zeros((args.examples//10, 1, args.sequence_length, 2), dtype=np.uint32)
Y_test = np.random.randint(0, 3, size=(args.examples//10), dtype=np.uint32)
Y_test = np.random.randint(0, 4, size=(args.examples//10), dtype=np.uint32)
for i in range(args.examples//10):
position_1 = np.random.randint(0, args.sequence_length-4)
position_2 = position_1+1
@@ -128,14 +142,29 @@ def main(args):
X_test[i,0,position_4,0] = 0
X_test[i,0,position_4,1] = 0

X_test[i,0,position_5,0] = 0
X_test[i,0,position_5,1] = 0
elif Y_test[i] == 2:
X_test[i,0,position_1,0] = 1
X_test[i,0,position_1,1] = 0

X_test[i,0,position_2,0] = 1
X_test[i,0,position_2,1] = 0

X_test[i,0,position_3,0] = 0
X_test[i,0,position_3,1] = 0

X_test[i,0,position_4,0] = 0
X_test[i,0,position_4,1] = 0

X_test[i,0,position_5,0] = 0
X_test[i,0,position_5,1] = 0
else:
X_test[i,0,position_1,0] = 1
X_test[i,0,position_1,1] = 0

X_test[i,0,position_2,0] = 0
X_test[i,0,position_2,1] = 1
X_test[i,0,position_2,1] = 0

X_test[i,0,position_3,0] = 0
X_test[i,0,position_3,1] = 0

0 comments on commit d6c7a9d

Please sign in to comment.