Skip to content

Commit

Permalink
update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Apr 3, 2023
1 parent d62759a commit 99350dd
Show file tree
Hide file tree
Showing 10 changed files with 21 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ repos:
rev: 23.1.0
hooks:
- id: black
files: ^(trlx|examples|tests|setup.py)/
files: ^(instruct_goose|tests|setup.py)/
11 changes: 5 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,19 @@ InstructGoose

<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

[<img src="https://img.shields.io/badge/license-MIT-blue">](https://github.com/vwxyzjn/cleanrl)
[![tests](https://github.com/vwxyzjn/cleanrl/actions/workflows/tests.yaml/badge.svg)](https://github.com/vwxyzjn/cleanrl/actions/workflows/tests.yaml)
[![docs](https://img.shields.io/github/deployments/vwxyzjn/cleanrl/Production?label=docs&logo=vercel.png)](https://docs.cleanrl.dev/)
[<img src="https://img.shields.io/badge/license-MIT-blue">](https://github.com/xrsrke/instructGOOSE)
[![tests](https://github.com/xrsrke/instructGOOSE/actions/workflows/tests.yaml/badge.svg)](https://github.com/xrsrke/instructGOOSE/actions/workflows/tests.yaml)
[![docs](https://img.shields.io/github/deployments/vwxyzjn/cleanrl/Production?label=docs&logo=vercel.png)](https://xrsrke.github.io/instructGOOSE/)
[![Code style:
black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Imports:
isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336.png)](https://pycqa.github.io/isort/)
[![Open In
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1HN1t75Jem9jXPzaOQdQO06OXMxinimNx?usp=sharing)
<!-- [![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) -->

Paper: InstructGPT - [Training language models to follow instructions
with human feedback](https://arxiv.org/abs/2203.02155)

![image.png](index_files/figure-commonmark/802fce9b-1-image.png)
![image.png](index_files/figure-commonmark/bca1bd5f-1-image.png)

## Install

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 3 additions & 3 deletions instruct_goose/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
__all__ = ['Agent', 'AgentObjective']

# %% ../nbs/02_agent.ipynb 4
from typing import Callable, Optional, Tuple
from typing import Callable, Tuple, Optional

import torch
import torch.nn.functional as F
from torch import nn
import torch.nn.functional as F
from torch.distributions import Categorical

from torchtyping import TensorType
from transformers import PreTrainedModel


# %% ../nbs/02_agent.ipynb 6
class Agent(nn.Module):
"The RL-based language model."
Expand Down
5 changes: 2 additions & 3 deletions instruct_goose/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
__all__ = ['PairDataset', 'PromptDataset']

# %% ../nbs/01_dataset.ipynb 4
from typing import Callable, Iterable, Tuple
from typing import Callable, Tuple, Iterable

from torch.utils.data import Dataset
from torchtyping import TensorType
from tqdm import tqdm

from torchtyping import TensorType

# %% ../nbs/01_dataset.ipynb 6
class PairDataset(Dataset):
Expand Down
3 changes: 1 addition & 2 deletions instruct_goose/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
# %% ../nbs/03_reward_model.ipynb 4
import torch
from torch import nn
from torchtyping import TensorType
from transformers import AutoModel, AutoTokenizer

from torchtyping import TensorType

# %% ../nbs/03_reward_model.ipynb 6
class RewardModel(nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions instruct_goose/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
from typing import Callable, Tuple

import torch
from einops import rearrange
from torchtyping import TensorType
from einops import rearrange

from transformers import PreTrainedModel

from .utils import RLHFConfig


# %% ../nbs/04_trainer.ipynb 6
class RLHFTrainer:
def __init__(
Expand Down
6 changes: 2 additions & 4 deletions instruct_goose/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
__all__ = ['load_yaml', 'RLHFConfig', 'create_reference_model', 'ModelConfig', 'TokenizerConfig', 'OptimizerConfig',
'TrainerConfig', 'PPOConfig', 'InstructConfig']

from copy import deepcopy
from dataclasses import dataclass

# %% ../nbs/05_utils.ipynb 3
import yaml

from copy import deepcopy
from dataclasses import dataclass

# %% ../nbs/05_utils.ipynb 4
def load_yaml(config_path):
Expand Down
10 changes: 5 additions & 5 deletions nbs/index.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"[<img src=\"https://img.shields.io/badge/license-MIT-blue\">](https://github.com/vwxyzjn/cleanrl)\n",
"[![tests](https://github.com/vwxyzjn/cleanrl/actions/workflows/tests.yaml/badge.svg)](https://github.com/vwxyzjn/cleanrl/actions/workflows/tests.yaml)\n",
"[![docs](https://img.shields.io/github/deployments/vwxyzjn/cleanrl/Production?label=docs&logo=vercel)](https://docs.cleanrl.dev/)\n",
"[<img src=\"https://img.shields.io/badge/license-MIT-blue\">](https://github.com/xrsrke/instructGOOSE)\n",
"[![tests](https://github.com/xrsrke/instructGOOSE/actions/workflows/tests.yaml/badge.svg)](https://github.com/xrsrke/instructGOOSE/actions/workflows/tests.yaml)\n",
"[![docs](https://img.shields.io/github/deployments/vwxyzjn/cleanrl/Production?label=docs&logo=vercel)](https://xrsrke.github.io/instructGOOSE/)\n",
"[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)\n",
"[![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1HN1t75Jem9jXPzaOQdQO06OXMxinimNx?usp=sharing)\n"
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1HN1t75Jem9jXPzaOQdQO06OXMxinimNx?usp=sharing)\n",
"<!-- [![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) -->"
]
},
{
Expand Down

0 comments on commit 99350dd

Please sign in to comment.