
Commit 5df431c

add flex_olmo model
1 parent 8f6e9c9 commit 5df431c

File tree

9 files changed: +843 -2 lines changed

mindone/transformers/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -539,6 +539,7 @@
     FlavaProcessor,
     FlavaTextModel,
 )
+from .models.flex_olmo import FlexOlmoForCausalLM, FlexOlmoModel, FlexOlmoPreTrainedModel
 from .models.fnet import (
     FNetForMaskedLM,
     FNetForMultipleChoice,
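
With this export in place, FlexOlmo can be used from mindone.transformers like the other causal LM models. A minimal sketch, assuming the usual from_pretrained/generate interface that mindone.transformers mirrors from Hugging Face; the checkpoint id and the tensor-conversion details below are placeholders, not part of this commit:

    # illustrative only; "allenai/FlexOlmo-7x7B-1T" is a placeholder checkpoint id
    import mindspore as ms
    from transformers import AutoTokenizer
    from mindone.transformers import FlexOlmoForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("allenai/FlexOlmo-7x7B-1T")
    model = FlexOlmoForCausalLM.from_pretrained("allenai/FlexOlmo-7x7B-1T")

    input_ids = ms.Tensor(tokenizer("MindSpore is", return_tensors="np").input_ids)
    output_ids = model.generate(input_ids, max_new_tokens=20)
    # assuming generate returns a MindSpore tensor
    print(tokenizer.decode(output_ids[0].asnumpy(), skip_special_tokens=True))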

mindone/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -79,6 +79,7 @@
     fastspeech2_conformer,
     flaubert,
     flava,
+    flex_olmo,
     fnet,
     focalnet,
     fsmt,

mindone/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions

@@ -100,6 +100,7 @@
         ("falcon_mamba", "FalconMambaConfig"),
         ("fastspeech2_conformer", "FastSpeech2ConformerConfig"),
         ("flava", "FlavaConfig"),
+        ("flex_olmo", "FlexOlmoConfig"),
         ("fnet", "FNetConfig"),
         ("focalnet", "FocalNetConfig"),
         ("fsmt", "FSMTConfig"),

@@ -368,6 +369,7 @@
         ("falcon_mamba", "FalconMamba"),
         ("fastspeech2_conformer", "FastSpeech2Conformer"),
         ("flava", "FLAVA"),
+        ("flex_olmo", "FlexOlmo"),
         ("fnet", "FNet"),
         ("focalnet", "FocalNet"),
         ("fsmt", "FairSeq Machine-Translation"),

mindone/transformers/models/auto/modeling_auto.py

Lines changed: 2 additions & 0 deletions

@@ -95,6 +95,7 @@
         ("falcon_mamba", "FalconMambaModel"),
         ("fastspeech2_conformer", "FastSpeech2ConformerModel"),
         ("flava", "FlavaModel"),
+        ("flex_olmo", "FlexOlmoModel"),
         ("fnet", "FNetModel"),
         ("focalnet", "FocalNetModel"),
         ("fsmt", "FSMTModel"),

@@ -433,6 +434,7 @@
         ("falcon", "FalconForCausalLM"),
         ("fuyu", "FuyuForCausalLM"),
         ("falcon_mamba", "FalconMambaForCausalLM"),
+        ("flex_olmo", "FlexOlmoForCausalLM"),
         ("gemma", "GemmaForCausalLM"),
         ("gemma2", "Gemma2ForCausalLM"),
         ("gemma3", "Gemma3ForCausalLM"),
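
These two mappings are what let the Auto classes resolve the "flex_olmo" model type. A small sketch of how they are exercised, assuming mindone.transformers exposes AutoModel and AutoModelForCausalLM with the usual from_config entry point; the config values are illustrative and the remaining fields fall back to FlexOlmoConfig defaults:

    # illustrative sketch; tiny config values chosen only to keep the example light
    from transformers.models.flex_olmo.configuration_flex_olmo import FlexOlmoConfig
    from mindone.transformers import AutoModel, AutoModelForCausalLM

    config = FlexOlmoConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2, num_attention_heads=4)
    backbone = AutoModel.from_config(config)        # resolved to FlexOlmoModel via the first mapping
    lm = AutoModelForCausalLM.from_config(config)   # resolved to FlexOlmoForCausalLM via the second mapping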
mindone/transformers/models/flex_olmo/__init__.py

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
+# coding=utf-8
+# Copyright 2025 the HuggingFace Team. All rights reserved.
+#
+# This code is adapted from https://github.com/huggingface/transformers
+# with modifications to run transformers on mindspore.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .modeling_flex_olmo import *

mindone/transformers/models/flex_olmo/modeling_flex_olmo.py

Lines changed: 679 additions & 0 deletions
Large diffs are not rendered by default.

tests/transformers_tests/causal_lm_tester.py

Lines changed: 0 additions & 2 deletions

@@ -39,7 +39,6 @@ def all_model_classes(self):

     def __init__(
         self,
-        parent,
         batch_size=13,
         seq_length=7,
         is_training=True,

@@ -80,7 +79,6 @@ def __init__(
         mamba_chunk_size=16,
     ):
         self._verify_model_attributes()
-        self.parent = parent
         self.batch_size = batch_size
         self.seq_length = seq_length
         self.is_training = is_training
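
With parent removed from the shared tester, model-specific testers are now constructed directly, with no back-reference to a test case. A minimal sketch of the updated usage, mirroring the new FlexOlmo test added in this commit (the absolute import path for the tester is assumed from the repository layout):

    from transformers.models.flex_olmo.configuration_flex_olmo import FlexOlmoConfig

    from mindone.transformers.models.flex_olmo import FlexOlmoModel
    from tests.transformers_tests.causal_lm_tester import CausalLMModelTester


    class FlexOlmoModelTester(CausalLMModelTester):
        base_model_class = FlexOlmoModel
        config_class = FlexOlmoConfig


    tester = FlexOlmoModelTester()  # no `parent` argument after this change
    config, input_ids, *_ = tester.prepare_config_and_inputs()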

tests/transformers_tests/models/flex_olmo/__init__.py

Whitespace-only changes.
Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
+# coding=utf-8
+# Copyright 2025 the HuggingFace Team. All rights reserved.
+#
+# This code is adapted from https://github.com/huggingface/transformers
+# with modifications to run transformers on mindspore.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the Mindspore FlexOlmo model."""
+
+import numpy as np
+import pytest
+import torch
+from transformers.models.flex_olmo.configuration_flex_olmo import FlexOlmoConfig
+
+import mindspore as ms
+
+from mindone.transformers.models.flex_olmo import FlexOlmoModel
+from tests.modeling_test_utils import compute_diffs, generalized_parse_args, get_modules
+
+from ...causal_lm_tester import CausalLMModelTester
+
+DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3, "bf16": 5e-2}
+MODES = [1]  # not support graph mode yet
+
+
+class FlexOlmoModelTester(CausalLMModelTester):
+    base_model_class = FlexOlmoModel
+    config_class = FlexOlmoConfig
+
+
+model_tester = FlexOlmoModelTester()
+(
+    config,
+    input_ids,
+    token_type_ids,
+    input_mask,
+    sequence_labels,
+    token_labels,
+    choice_labels,
+) = model_tester.prepare_config_and_inputs()
+
+
+FLEXOLMO_CASES = [
+    [
+        "FlexOlmoModel",
+        "transformers.FlexOlmoModel",
+        "mindone.transformers.FlexOlmoModel",
+        (config,),
+        {},
+        (),
+        {
+            "input_ids": input_ids,
+            "attention_mask": input_mask,
+        },
+        {
+            "last_hidden_state": 0,
+        },
+    ],
+]
+
+
+# transformers need >= 4.41.2
+@pytest.mark.parametrize(
+    "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode",
+    [
+        case
+        + [
+            dtype,
+        ]
+        + [
+            mode,
+        ]
+        for case in FLEXOLMO_CASES
+        for dtype in DTYPE_AND_THRESHOLDS.keys()
+        for mode in MODES
+    ],
+)
+def test_named_modules(
+    name,
+    pt_module,
+    ms_module,
+    init_args,
+    init_kwargs,
+    inputs_args,
+    inputs_kwargs,
+    outputs_map,
+    dtype,
+    mode,
+):
+    ms.set_context(mode=mode)
+
+    (
+        pt_model,
+        ms_model,
+        pt_dtype,
+        ms_dtype,
+    ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs)
+    pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args(
+        pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs
+    )
+
+    # set `hidden_dtype` if requiring, for some modules always compute in float
+    # precision and require specific `hidden_dtype` to cast before return
+    with torch.no_grad():
+        pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs)
+    ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs)
+    # print("ms:", ms_outputs)
+    # print("pt:", pt_outputs)
+    if outputs_map:
+        pt_outputs_n = []
+        ms_outputs_n = []
+        for pt_key, ms_idx in outputs_map.items():
+            # print("===map", pt_key, ms_idx)
+            pt_output = getattr(pt_outputs, pt_key)
+            ms_output = ms_outputs[ms_idx]
+            if isinstance(pt_output, (list, tuple)):
+                pt_outputs_n += list(pt_output)
+                ms_outputs_n += list(ms_output)
+            else:
+                pt_outputs_n.append(pt_output)
+                ms_outputs_n.append(ms_output)
+        diffs = compute_diffs(pt_outputs_n, ms_outputs_n)
+    else:
+        diffs = compute_diffs(pt_outputs, ms_outputs)
+
+    THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype]
+    assert (np.array(diffs) < THRESHOLD).all(), (
+        f"ms_dtype: {ms_dtype}, pt_type:{pt_dtype}, "
+        f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD}"
+    )
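
For reference, the parametrize block above crosses FLEXOLMO_CASES with the dtype thresholds and modes, so one case, three dtypes, and a single (PyNative) mode yield three test instances. A small standalone illustration of that expansion, separate from the commit itself:

    # standalone illustration of the parametrization above (not part of the commit)
    cases = ["FlexOlmoModel"]
    dtypes = ["fp32", "fp16", "bf16"]
    modes = [1]  # PyNative only; graph mode is not supported yet
    expanded = [(c, d, m) for c in cases for d in dtypes for m in modes]
    print(len(expanded))  # 3 parametrized runs of test_named_modules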
