@@ -40,100 +40,112 @@ <h2>概要</h2>
40
40
< h2 > 方法</ h2 >
41
41
< p > …これで終わるわけにも行かないので、方法を書き散らしておきます。< br />
42
42
まずは、Llamaを教師モデルとして自分のモデルを訓練するために、そのプログラムをPerplexityに書いてもらいます。
43
- そのプログラムがこちらです。(動作確認していないので、ご注意)
44
- ``` python: 初回用
43
+ そのプログラムがこちらです。(動作確認していないので、ご注意)</ p >
44
+ < pre > < code class =" language- python" > # 初回用
45
45
import torch
46
46
import torch.nn as nn
47
47
import torch.optim as optim
48
48
from torch.utils.data import Dataset, DataLoader
49
49
from transformers import MllamaForConditionalGeneration, AutoProcessor
50
- from logging import getLogger, Formatter, StreamHandler, DEBUG</ p >
51
- < h1 > ロガーの設定</ h1 >
52
- < p > logger = getLogger(< strong > name</ strong > )
50
+ from logging import getLogger, Formatter, StreamHandler, DEBUG
51
+
52
+ # ロガーの設定
53
+ logger = getLogger(__name__)
53
54
logger.setLevel(DEBUG)
54
55
handler = StreamHandler()
55
56
handler.setLevel(DEBUG)
56
57
formatter = Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
57
58
handler.setFormatter(formatter)
58
- logger.addHandler(handler)</ p >
59
- < h1 > WikipediaDatasetの定義</ h1 >
60
- < p > class WikipediaDataset(Dataset):
61
- def < strong > init</ strong > (self, file_path, max_length=512):
59
+ logger.addHandler(handler)
60
+
61
+ # WikipediaDatasetの定義
62
+ class WikipediaDataset(Dataset):
63
+ def __init__(self, file_path, max_length=512):
62
64
self.file_path = file_path
63
65
self.max_length = max_length
64
- self.data = self.load_annotations()</ p >
65
- < pre > < code > def load_annotations(self):
66
- with open(self.file_path, 'r', encoding='utf-8') as f:
67
- return [line.strip() for line in f if line.strip()]
66
+ self.data = self.load_annotations()
67
+
68
+ def load_annotations(self):
69
+ with open(self.file_path, 'r', encoding='utf-8') as f:
70
+ return [line.strip() for line in f if line.strip()]
68
71
69
- def __len__(self):
70
- return len(self.data)
72
+ def __len__(self):
73
+ return len(self.data)
71
74
72
- def __getitem__(self, idx):
73
- return self.data[idx][:self.max_length]
74
- </ code > </ pre >
75
- < h1 > SmallModelの定義</ h1 >
76
- < p > class SmallModel(nn.Module):
77
- def < strong > init </ strong > (self, input_size, hidden_size, output_size):
78
- super(SmallModel, self).< strong > init </ strong > ()
75
+ def __getitem__(self, idx):
76
+ return self.data[idx][:self.max_length]
77
+
78
+ # SmallModelの定義
79
+ class SmallModel(nn.Module):
80
+ def __init__ (self, input_size, hidden_size, output_size):
81
+ super(SmallModel, self).__init__ ()
79
82
self.fc1 = nn.Linear(input_size, hidden_size)
80
83
self.fc2 = nn.Linear(hidden_size, output_size)
81
- self.relu = nn.ReLU()</ p >
82
- < pre > < code > def forward(self, x):
83
- x = self.relu(self.fc1(x))
84
- x = self.fc2(x)
85
- return x
86
- </ code > </ pre >
87
- < h1 > ローカルのLlama 3.2モデルのロード</ h1 >
88
- < p > model_path = "path/to/your/local/llama3.2/model"
84
+ self.relu = nn.ReLU()
85
+
86
+ def forward(self, x):
87
+ x = self.relu(self.fc1(x))
88
+ x = self.fc2(x)
89
+ return x
90
+
91
+ # ローカルのLlama 3.2モデルのロード
92
+ model_path = "path/to/your/local/llama3.2/model"
89
93
teacher_model = MllamaForConditionalGeneration.from_pretrained(
90
94
model_path,
91
95
torch_dtype=torch.bfloat16,
92
- device_map=" auto"
96
+ device_map=" auto"
93
97
)
94
- processor = AutoProcessor.from_pretrained(model_path)</ p >
95
- < h1 > SmallModelの初期化</ h1 >
96
- < p > input_size = 768 # 入力サイズ(実際のタスクに合わせて調整)
98
+ processor = AutoProcessor.from_pretrained(model_path)
99
+
100
+ # SmallModelの初期化
101
+ input_size = 768 # 入力サイズ(実際のタスクに合わせて調整)
97
102
hidden_size = 256
98
103
output_size = teacher_model.config.vocab_size
99
- student_model = SmallModel(input_size, hidden_size, output_size).to("cuda")</ p >
100
- < h1 > データセットとDataLoaderの準備</ h1 >
101
- < p > dataset = WikipediaDataset("path/to/wiki.txt")
102
- dataloader = DataLoader(dataset, batch_size=2, shuffle=True)</ p >
103
- < h1 > 損失関数とオプティマイザの設定</ h1 >
104
- < p > criterion = nn.KLDivLoss(reduction='batchmean')
105
- optimizer = optim.Adam(student_model.parameters())</ p >
106
- < h1 > 学習ループ</ h1 >
107
- < p > num_epochs = 10
108
- temperature = 2.0</ p >
109
- < p > for epoch in range(num_epochs):
104
+ student_model = SmallModel(input_size, hidden_size, output_size).to("cuda")
105
+
106
+ # データセットとDataLoaderの準備
107
+ dataset = WikipediaDataset("path/to/wiki.txt")
108
+ dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
109
+
110
+ # 損失関数とオプティマイザの設定
111
+ criterion = nn.KLDivLoss(reduction='batchmean')
112
+ optimizer = optim.Adam(student_model.parameters())
113
+
114
+ # 学習ループ
115
+ num_epochs = 10
116
+ temperature = 2.0
117
+
118
+ for epoch in range(num_epochs):
110
119
for batch in dataloader:
111
120
# 入力の処理
112
- inputs = processor(batch, return_tensors="pt", padding=True, truncation=True).to("cuda")</ p >
113
- < pre > < code > # 教師モデルの出力を取得
114
- with torch.no_grad():
115
- teacher_outputs = teacher_model(**inputs).logits
116
-
117
- # 生徒モデルの出力を取得
118
- student_outputs = student_model(inputs.input_ids)
119
-
120
- # 知識蒸留損失の計算
121
- loss = criterion(
122
- nn.functional.log_softmax(student_outputs / temperature, dim=-1),
123
- nn.functional.softmax(teacher_outputs / temperature, dim=-1)
124
- )
125
-
126
- # 逆伝播と最適化
127
- optimizer.zero_grad()
128
- loss.backward()
129
- optimizer.step()
130
-
131
- logger.debug(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")
121
+ inputs = processor(batch, return_tensors="pt", padding=True, truncation=True).to("cuda")
122
+
123
+ # 教師モデルの出力を取得
124
+ with torch.no_grad():
125
+ teacher_outputs = teacher_model(**inputs).logits
126
+
127
+ # 生徒モデルの出力を取得
128
+ student_outputs = student_model(inputs.input_ids)
129
+
130
+ # 知識蒸留損失の計算
131
+ loss = criterion(
132
+ nn.functional.log_softmax(student_outputs / temperature, dim=-1),
133
+ nn.functional.softmax(teacher_outputs / temperature, dim=-1)
134
+ )
135
+
136
+ # 逆伝播と最適化
137
+ optimizer.zero_grad()
138
+ loss.backward()
139
+ optimizer.step()
140
+
141
+ logger.debug(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")
142
+
143
+ # モデルの保存
144
+ torch.save(student_model.state_dict(), "small_model.pth")
145
+
146
+ logger.debug("モデルが保存されました。") )
132
147
</ code > </ pre >
133
- < h1 > モデルの保存</ h1 >
134
- < p > torch.save(student_model.state_dict(), "small_model.pth")</ p >
135
- < p > logger.debug("モデルが保存されました。") )</ p >
136
- < pre > < code > ```python:2回目以降用
148
+ < pre > < code class ="language-python "> # 2回目以降用
137
149
import torch
138
150
import torch.nn as nn
139
151
import torch.optim as optim
@@ -276,10 +288,10 @@ <h2>Wikipediaだけじゃ…</h2>
276
288
これで、CC BY-SAとCC BY-NC-SAの組み合わせだったら、死んでいたところでした。 </ p >
277
289
< h2 > 今後</ h2 >
278
290
< p > とりあえず開発途中に書いた記事なので、今後もどんどん更新を入れていきます。</ p >
279
- < button class ="button "
280
- onclick ="location.href = 'https://shizukani-cp.github.io/blog/articles/20241201/' "> 次の記事</ button >
281
- < button class ="button "
291
+ < button class ="button back-next "
282
292
onclick ="location.href = 'https://shizukani-cp.github.io/blog/articles/20241129/' "> 前の記事</ button >
293
+ < button class ="button back-next "
294
+ onclick ="location.href = 'https://shizukani-cp.github.io/blog/articles/20241201/' "> 次の記事</ button >
283
295
</ main >
284
296
< aside id ="sidebar "> </ aside >
285
297
< script src ="../../scripts/articles.json.js "> </ script >
0 commit comments