Skip to content

Commit

Permalink
Merge pull request #108 from FengZiYjun/trainer
Browse files Browse the repository at this point in the history
FastNLP v0.2
  • Loading branch information
xpqiu authored Dec 7, 2018
2 parents 15262bd + db0a789 commit 1b477a9
Show file tree
Hide file tree
Showing 60 changed files with 5,432 additions and 1,731 deletions.
47 changes: 33 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,39 @@
![Hex.pm](https://img.shields.io/hexpm/l/plug.svg)
[![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest)

fastNLP is a modular Natural Language Processing system based on PyTorch, for fast development of NLP tools. It divides the NLP model based on deep learning into different modules. These modules fall into 4 categories: encoder, interaction, aggregation and decoder, while each category contains different implemented modules. Encoder modules encode the input into some abstract representation, interaction modules make the information in the representation interact with each other, aggregation modules aggregate and reduce information, and decoder modules decode the representation into the output. Most current NLP models can be built on these modules, which vastly simplifies the process of developing NLP models. The architecture of fastNLP is shown in the figure below:
FastNLP is a modular Natural Language Processing system based on PyTorch, built for fast development of NLP models.

![](https://github.com/fastnlp/fastNLP/raw/master/docs/source/figures/procedures.PNG)
![](https://github.com/fastnlp/fastNLP/raw/master/docs/source/figures/text_classification.png)
A deep learning NLP model is the composition of three types of modules:
<table>
<tr>
<td><b> module type </b></td>
<td><b> functionality </b></td>
<td><b> example </b></td>
</tr>
<tr>
<td> encoder </td>
<td> encode the input into some abstract representation </td>
<td> embedding, RNN, CNN, transformer </td>
</tr>
<tr>
<td> aggregator </td>
<td> aggregate and reduce information </td>
<td> self-attention, max-pooling </td>
</tr>
<tr>
<td> decoder </td>
<td> decode the representation into the output </td>
<td> MLP, CRF </td>
</tr>
</table>

For example:

![](docs/source/figures/text_classification.png)

## Requirements

- numpy>=1.14.2
- torch>=0.4.0
- torchvision>=0.1.8
- tensorboardX


Expand All @@ -39,12 +62,12 @@ pip install fastNLP
<td> an open-source NLP library </td>
</tr>
<tr>
<td><b> fastNLP.core </b></td>
<td> trainer, tester, predictor </td>
<td><b> fastNLP.api </b></td>
<td> APIs for end-to-end prediction </td>
</tr>
<tr>
<td><b> fastNLP.loader </b></td>
<td> all kinds of loaders/readers </td>
<td><b> fastNLP.core </b></td>
<td> data representation & train/test procedure </td>
</tr>
<tr>
<td><b> fastNLP.models </b></td>
Expand All @@ -55,11 +78,7 @@ pip install fastNLP
<td> a collection of PyTorch sub-models/components/wheels </td>
</tr>
<tr>
<td><b> fastNLP.saver </b></td>
<td> all kinds of savers/writers </td>
</tr>
<tr>
<td><b> fastNLP.fastnlp </b></td>
<td> a high-level interface for prediction </td>
<td><b> fastNLP.io </b></td>
<td> readers & savers </td>
</tr>
</table>
3 changes: 2 additions & 1 deletion docs/quick_tutorial.md
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
# FastNLP Quick Tutorial
# FastNLP Quick Tutorial

Binary file modified docs/source/figures/text_classification.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 2 additions & 6 deletions fastNLP/api/model_zoo.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import torch

import hashlib
import os
import re
import shutil
import sys
import tempfile

import torch

try:
from requests.utils import urlparse
from requests import get as urlopen
Expand Down Expand Up @@ -132,7 +132,3 @@ def __exit__(self, exc_type, exc_val, exc_tb):

sys.stderr.write('\n')


if __name__ == '__main__':
pipeline = load_url('http://10.141.208.102:5000/file/download/infer_context-4e86fd93.pkl', model_dir='.')
print(type(pipeline))
38 changes: 25 additions & 13 deletions fastNLP/api/processor.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import torch
from collections import defaultdict
import re
from collections import defaultdict

import torch

from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.vocabulary import Vocabulary


class Processor:
class Processor(object):
def __init__(self, field_name, new_added_field_name):
self.field_name = field_name
if new_added_field_name is None:
Expand All @@ -17,7 +18,7 @@ def __init__(self, field_name, new_added_field_name):
self.new_added_field_name = new_added_field_name

def process(self, *args, **kwargs):
pass
raise NotImplementedError

def __call__(self, *args, **kwargs):
return self.process(*args, **kwargs)
Expand Down Expand Up @@ -132,27 +133,29 @@ def process(self, dataset):


class IndexerProcessor(Processor):
def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False):
def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True):

assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))

super(IndexerProcessor, self).__init__(field_name, new_added_field_name)
self.vocab = vocab
self.delete_old_field = delete_old_field
self.is_input = is_input

def set_vocab(self, vocab):
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))

self.vocab = vocab

def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
for ins in dataset:
tokens = ins[self.field_name]
index = [self.vocab.to_index(token) for token in tokens]
ins[self.new_added_field_name] = index

dataset._set_need_tensor(**{self.new_added_field_name: True})
if self.is_input:
dataset.set_input(self.new_added_field_name)

if self.delete_old_field:
dataset.delete_field(self.field_name)
Expand All @@ -161,6 +164,9 @@ def process(self, dataset):


class VocabProcessor(Processor):
"""Build vocabulary with a field in the data set.
"""
def __init__(self, field_name):
super(VocabProcessor, self).__init__(field_name, None)
self.vocab = Vocabulary()
Expand All @@ -178,17 +184,20 @@ def get_vocab(self):


class SeqLenProcessor(Processor):
def __init__(self, field_name, new_added_field_name='seq_lens'):
def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True):
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)
self.is_input = is_input

def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
length = len(ins[self.field_name])
ins[self.new_added_field_name] = length
dataset._set_need_tensor(**{self.new_added_field_name: True})
if self.is_input:
dataset.set_input(self.new_added_field_name)
return dataset


class ModelProcessor(Processor):
def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32):
"""
Expand Down Expand Up @@ -238,6 +247,7 @@ def set_model_device(self, device):
device = torch.device(device)
self.model.to(device)


class Index2WordProcessor(Processor):
def __init__(self, vocab, field_name, new_added_field_name):
super(Index2WordProcessor, self).__init__(field_name, new_added_field_name)
Expand All @@ -251,26 +261,28 @@ def process(self, dataset):


class SetTensorProcessor(Processor):
# TODO: remove it. It is strange.
def __init__(self, field_dict, default=False):
super(SetTensorProcessor, self).__init__(None, None)
self.field_dict = field_dict
self.default = default

def process(self, dataset):
set_dict = {name: self.default for name in dataset.get_fields().keys()}
set_dict = {name: self.default for name in dataset.get_all_fields().keys()}
set_dict.update(self.field_dict)
dataset._set_need_tensor(**set_dict)
return dataset


class SetIsTargetProcessor(Processor):
# TODO; remove it.
def __init__(self, field_dict, default=False):
super(SetIsTargetProcessor, self).__init__(None, None)
self.field_dict = field_dict
self.default = default

def process(self, dataset):
set_dict = {name: self.default for name in dataset.get_fields().keys()}
set_dict = {name: self.default for name in dataset.get_all_fields().keys()}
set_dict.update(self.field_dict)
dataset.set_target(**set_dict)
return dataset
10 changes: 6 additions & 4 deletions fastNLP/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from .batch import Batch
from .dataset import DataSet
# from .dataset import DataSet
from .fieldarray import FieldArray
from .instance import Instance
from .metrics import Evaluator, ClassifyEvaluator, SNLIEvaluator, SeqLabelEvaluator
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
from .metrics import AccuracyMetric
from .optimizer import Optimizer, SGD, Adam
from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
from .tester import Tester
from .trainer import Trainer
from .vocabulary import Vocabulary
from .optimizer import Optimizer
from .loss import Loss
from ..io.dataset_loader import DataSet

17 changes: 15 additions & 2 deletions fastNLP/core/batch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import torch


Expand Down Expand Up @@ -25,6 +26,7 @@ def __init__(self, dataset, batch_size, sampler, as_numpy=False):
self.as_numpy = as_numpy
self.idx_list = None
self.curidx = 0
self.num_batches = len(dataset)//batch_size + int(len(dataset)%batch_size!=0)

def __iter__(self):
self.idx_list = self.sampler(self.dataset)
Expand All @@ -41,11 +43,11 @@ def __next__(self):

indices = self.idx_list[self.curidx:endidx]

for field_name, field in self.dataset.get_fields().items():
for field_name, field in self.dataset.get_all_fields().items():
if field.is_target or field.is_input:
batch = field.get(indices)
if not self.as_numpy:
batch = torch.from_numpy(batch)
batch = to_tensor(batch, field.dtype)
if field.is_target:
batch_y[field_name] = batch
if field.is_input:
Expand All @@ -54,3 +56,14 @@ def __next__(self):
self.curidx = endidx

return batch_x, batch_y

def __len__(self):
return self.num_batches


def to_tensor(batch, dtype):
    """Convert a batch to a torch tensor chosen by the field's dtype.

    :param batch: the batched field content handed over by the caller
        (anything ``torch.LongTensor`` / ``torch.FloatTensor`` accepts).
    :param dtype: the element type recorded for the field. Python ``int`` and
        the signed NumPy integer types map to ``torch.LongTensor``; Python
        ``float``, ``np.float32`` and ``np.float64`` map to
        ``torch.FloatTensor``.
    :return: the converted tensor, or ``batch`` itself, untouched, when
        ``dtype`` is not one of the types listed above.
    """
    if dtype in (int, np.int8, np.int16, np.int32, np.int64):
        return torch.LongTensor(batch)
    if dtype in (float, np.float32, np.float64):
        return torch.FloatTensor(batch)
    # Any other dtype is deliberately passed through unconverted.
    return batch
Loading

0 comments on commit 1b477a9

Please sign in to comment.