Fixing various import issues.
pbloem committed Aug 3, 2021
1 parent f5351a5 commit 557e767
Showing 4 changed files with 16 additions and 15 deletions.
experiments/classify.py (6 changes: 3 additions & 3 deletions)
@@ -1,13 +1,13 @@
from former import util

-from util import d, here
+from former.util import d, here

import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F

-from torchtext import data, datasets, vocab
+# from torchtext import data, datasets, vocab
+from torchtext.legacy import data, datasets, vocab

import numpy as np

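The switch to `torchtext.legacy` pins this script to the torchtext releases (roughly 0.9–0.11) that still ship the old `data`/`datasets`/`vocab` APIs under that namespace. A hedged, version-tolerant sketch (not part of this commit) would fall back automatically:

```python
# Illustrative only: prefer the legacy namespace where it exists
# (torchtext >= 0.9 moved the old Field/Iterator-style APIs there).
try:
    from torchtext.legacy import data, datasets, vocab
except ImportError:
    from torchtext import data, datasets, vocab
```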
former/modules.py (6 changes: 3 additions & 3 deletions)
@@ -1,4 +1,4 @@
-from .util import mask_, d
+from .util import mask_, d, slice_diag

import torch
from torch import nn
@@ -445,14 +445,14 @@ def forward(self, x):
assert dot_tt.size()== (b*h, t, t), f'{dot_tt.size()}'

dot_tp = torch.einsum('bis, bjs -> bij', queries, keys_pos) # -- token with position
-dot_tp = .util.slice_diag(dot_tp, l=t)
+dot_tp = slice_diag(dot_tp, l=t)
assert dot_tp.size() == (b*h, t, t), f'{dot_tp.size()}'

dot_pt = torch.einsum('bis, bjs -> bij', parma, keys) # -- position with token
assert dot_pt.size() == (b*h, t, t), f'{dot_pt.size()}'

dot_pp = torch.einsum('bis, bjs -> bij', parmb, keys_pos) # -- pos with pos
-dot_pp = .util.slice_diag(dot_pp, l=t)
+dot_pp = slice_diag(dot_pp, l=t)
assert dot_pp.size() == (b*h, t, t), f'{dot_pp.size()}'

dot = dot_tt + dot_tp + dot_pt + dot_pp
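For orientation (inferred from the asserts, not stated in the diff): `keys_pos` covers the 2t−1 possible relative offsets, so the token-with-position scores come out as (b·h, t, 2t−1), and `slice_diag` reduces them to (b·h, t, t). A small shape sketch under those assumed dimensions:

```python
import torch

# Hypothetical sizes for illustration only
b_h, t, e = 2 * 4, 5, 16                      # batch*heads, sequence length, head dim
queries  = torch.randn(b_h, t, e)
keys_pos = torch.randn(b_h, 2 * t - 1, e)     # one key per relative offset -(t-1)..t-1
dot_tp = torch.einsum('bis, bjs -> bij', queries, keys_pos)
assert dot_tp.size() == (b_h, t, 2 * t - 1)   # slice_diag then maps this to (b_h, t, t)
```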
former/util/util.py (16 changes: 8 additions & 8 deletions)
@@ -3,7 +3,6 @@
import torch.nn.functional as F
import torch.distributions as dist

-import transformers as trf
from torch.utils.tensorboard import SummaryWriter

import numpy as np
@@ -179,15 +178,17 @@ def slice_diag(matrix, l, dv=None):
dv = d(matrix)

h, w = matrix.size(-2), matrix.size(-1)
-assert w == 2 * l -1, f'{(h, w)=} {l=}'
+
+assert w == 2 * l -1, f'(h, w)= {(h, w)}, l={l}'
+
rest = matrix.size()[:-2]

matrix = matrix.view(-1, h, w)
b, h, w = matrix.size()

result = matrix.view(b, -1)
result = torch.cat([result, torch.zeros(b, l, device=dv)], dim=1)
-assert result.size() == (b, 2 * l * l), f'{result.size()=}'
+assert result.size() == (b, 2 * l * l), f'result.size() {result.size()}'

result = result.view(b, l, 2*l)
result = result[:, :, :l]
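A minimal usage sketch for `slice_diag`, assuming the package is installed and exports it the same way it exports `d` and `here` (the sizes are made up; the last dimension must be 2l−1, as the assert above requires):

```python
import torch
from former.util import slice_diag   # assumes an installed former package that exports it

l = 4
m = torch.randn(3, l, 2 * l - 1)      # last dimension must be 2*l - 1
out = slice_diag(m, l=l)
print(out.size())                      # expected: torch.Size([3, 4, 4])
```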
@@ -200,7 +201,7 @@ def slice_diag(matrix, l, dv=None):
LOGE2 = math.log(2.0)

def compute_compression(model, data, context, batch_size, verbose=False,
-tbw:SummaryWriter=None, tok:trf.GPT2Tokenizer=None, skip=0):
+tbw:SummaryWriter=None, tok=None, skip=0):


"""
@@ -224,8 +225,10 @@ def compute_compression(model, data, context, batch_size, verbose=False,
# need to shift the start/end indices ahead by one token.
#
# After we pass the batch through the model, we look at only the probabilities predicted for the last token.
+
target_indices = []
i, ic = 0, 0
+
for current in tqdm.trange(skip, data.size(0)) if verbose else range(skip, data.size(0)):

# `current` is the character which we will ultimately predict
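The comments above describe the sliding-window evaluation: for each position, feed the model the preceding context and read off, from the last output position, the probability it assigns to the character being predicted. A minimal single-example sketch of that accounting (it skips the batching, and the dummy model and sizes are assumed, not taken from this commit):

```python
import math
import torch
import torch.nn.functional as F

vocab, context = 256, 32
data = torch.randint(0, vocab, (200,))        # toy character-level sequence

def dummy_model(x):                           # stand-in: uniform logits over the vocabulary
    return torch.zeros(x.size(0), x.size(1), vocab)

bits = 0.0
for current in range(1, data.size(0)):
    fr = max(0, current - context)
    inp = data[fr:current].unsqueeze(0)       # context ending just before `current`
    logprobs = F.log_softmax(dummy_model(inp), dim=-1)
    lp = logprobs[0, -1, data[current]]       # prediction read off the last position only
    bits += -lp.item() / math.log(2.0)        # nats -> bits

print(bits / (data.size(0) - 1), 'bits per character')   # uniform model gives 8.0
```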
@@ -301,10 +304,7 @@ def compute_compression(model, data, context, batch_size, verbose=False,
if isinstance(bits, torch.Tensor):
bits = bits.item()

-if tok is not None:
-    return bits, ic # total nr of bits used, total nr of characters seen
-else:
-    return bits # total nr of bits used
+return bits # total nr of bits used

def estimate_compression(model, data, nsamples, context, batch_size, verbose=False):
"""
setup.py (3 changes: 2 additions & 1 deletion)
@@ -12,6 +12,7 @@
'torch',
'tqdm',
'numpy',
-'torchtext'
+'torchtext',
+'tensorboard'
],
zip_safe=False)
