diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml
new file mode 100644
index 0000000..fb4ded7
--- /dev/null
+++ b/.github/workflows/pre-commit.yaml
@@ -0,0 +1,24 @@
+name: pre-commit
+ pull_request:
+ push:
+ branches: [main]
+ check_and_test:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - uses: actions/setup-python@v4
+ id: ko-sentence-transformers
+ with:
+ python-version: '3.10'
+ cache: 'pip'
+ - name: pre-commit
+ run: |
+ pip install --upgrade pip
+ pip install -U pre-commit
+ pre-commit install --install-hooks
+ pre-commit run -a
\ No newline at end of file
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..2284a53
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,30 @@
+exclude: ^(legacy|bin)
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.0.1
+ hooks:
+ - id: end-of-file-fixer
+ types: [python]
+ - id: trailing-whitespace
+ types: [python]
+ - id: mixed-line-ending
+ types: [python]
+ - id: check-added-large-files
+ args: [--maxkb=4096]
+ - repo: https://github.com/psf/black
+ rev: 22.3.0
+ hooks:
+ - id: black
+ args: ["--line-length", "120"]
+ - repo: https://github.com/pycqa/isort
+ rev: 5.12.0
+ hooks:
+ - id: isort
+ name: isort (python)
+ args: ["--profile", "black", "-l", "120"]
+ - repo: https://github.com/pycqa/flake8.git
+ rev: 6.0.0
+ hooks:
+ - id: flake8
+ types: [python]
+ args: ["--max-line-length", "120", "--ignore", "F811,F841,E203,E402,E712,W503"]
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..53d7025
--- /dev/null
@@ -0,0 +1,427 @@
+Attribution-ShareAlike 4.0 International
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+Using Creative Commons Public Licenses
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+Creative Commons Attribution-ShareAlike 4.0 International Public
+License
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-ShareAlike 4.0 International Public License ("Public
+License"). To the extent this Public License may be interpreted as a
+contract, You are granted the Licensed Rights in consideration of Your
+acceptance of these terms and conditions, and the Licensor grants You
+such rights in consideration of benefits the Licensor receives from
+making the Licensed Material available under these terms and
+conditions.
+Section 1 -- Definitions.
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+ c. BY-SA Compatible License means a license listed at
+ creativecommons.org/compatiblelicenses, approved by Creative
+ Commons as essentially the equivalent of this Public License.
+ d. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+ e. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+ g. License Elements means the license attributes listed in the name
+ of a Creative Commons Public License. The License Elements of this
+ Public License are Attribution and ShareAlike.
+ h. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+ i. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+ j. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+ k. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+ l. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+ m. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+Section 2 -- Scope.
+ a. License grant.
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+ a. reproduce and Share the Licensed Material, in whole or
+ in part; and
+ b. produce, reproduce, and Share Adapted Material.
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+ 5. Downstream recipients.
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+ b. Additional offer from the Licensor -- Adapted Material.
+ Every recipient of Adapted Material from You
+ automatically receives an offer from the Licensor to
+ exercise the Licensed Rights in the Adapted Material
+ under the conditions of the Adapter's License You apply.
+ c. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+ b. Other rights.
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties.
+Section 3 -- License Conditions.
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+ a. Attribution.
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+ ii. a copyright notice;
+ iii. a notice that refers to this Public License;
+ iv. a notice that refers to the disclaimer of
+ warranties;
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+ b. ShareAlike.
+ In addition to the conditions in Section 3(a), if You Share
+ Adapted Material You produce, the following conditions also apply.
+ 1. The Adapter's License You apply must be a Creative Commons
+ license with the same License Elements, this version or
+ later, or a BY-SA Compatible License.
+ 2. You must include the text of, or the URI or hyperlink to, the
+ Adapter's License You apply. You may satisfy this condition
+ in any reasonable manner based on the medium, means, and
+ context in which You Share Adapted Material.
+ 3. You may not offer or impose any additional or different terms
+ or conditions on, or apply any Effective Technological
+ Measures to, Adapted Material that restrict exercise of the
+ rights granted under the Adapter's License You apply.
+Section 4 -- Sui Generis Database Rights.
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database;
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material,
+ including for purposes of Section 3(b); and
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+Section 6 -- Term and Termination.
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+ 2. upon express reinstatement by the Licensor.
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+Section 7 -- Other Terms and Conditions.
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+Section 8 -- Interpretation.
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+Creative Commons is not a party to its public licenses.
+Notwithstanding, Creative Commons may elect to apply one of its public
+licenses to material it publishes and in those instances will be
+considered the “Licensor.” The text of the Creative Commons public
+licenses is dedicated to the public domain under the CC0 Public Domain
+Dedication. Except for the limited purpose of indicating that material
+is shared under a Creative Commons public license or as otherwise
+permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the public
+licenses.
+Creative Commons may be contacted at creativecommons.org.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..85f8e7b
--- /dev/null
+++ b/README.md
@@ -0,0 +1,238 @@
+# kf-deberta-multitask
+kakaobank의 [kf-deberta-base](https://huggingface.co/kakaobank/kf-deberta-base) 모델을 KorNLI, KorSTS 데이터셋으로 파인튜닝한 모델입니다.
+[jhgan00/ko-sentence-transformers](https://github.com/jhgan00/ko-sentence-transformers) 코드를 기반으로 일부 수정하여 진행하였습니다.
+## KorSTS Benchmark
+- [jhgan00/ko-sentence-transformers](https://github.com/jhgan00/ko-sentence-transformers#korsts-benchmarks)의 결과를 참고하여 재작성하였습니다.
+- 학습 및 성능 평가 과정은 `training_*.py`, `benchmark.py` 에서 확인할 수 있습니다.
+- 학습된 모델은 허깅페이스 모델 허브에 공개되어 있습니다.
+## Examples
+아래는 임베딩 벡터를 통해 가장 유사한 문장을 찾는 예시입니다.
+더 많은 예시는 [sentence-transformers 문서](https://www.sbert.net/index.html)를 참고해주세요.
+from transformers import AutoTokenizer, AutoModel
+import torch
+# Mean Pooling - Take attention mask into account for correct averaging
+def mean_pooling(model_output, attention_mask):
+ token_embeddings = model_output[0] # First element of model_output contains all token embeddings
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+# Sentences we want sentence embeddings for
+sentences = ["경제 전문가가 금리 인하에 대한 예측을 하고 있다.", "주식 시장에서 한 투자자가 주식을 매수한다."]
+# Load model from HuggingFace Hub
+tokenizer = AutoTokenizer.from_pretrained("upskyy/kf-deberta-multitask")
+model = AutoModel.from_pretrained("upskyy/kf-deberta-multitask")
+# Tokenize sentences
+encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
+# Compute token embeddings
+with torch.no_grad():
+ model_output = model(**encoded_input)
+# Perform pooling. In this case, mean pooling.
+sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+print("Sentence embeddings:")
+from sentence_transformers import SentenceTransformer, util
+import numpy as np
+# Sentence transformer model for financial domain
+embedder = SentenceTransformer("upskyy/kf-deberta-multitask")
+# Financial domain corpus
+corpus = [
+ "주식 시장에서 한 투자자가 주식을 매수한다.",
+ "은행에서 예금을 만기로 인출하는 고객이 있다.",
+ "금융 전문가가 새로운 투자 전략을 개발하고 있다.",
+ "증권사에서 주식 포트폴리오를 관리하는 팀이 있다.",
+ "금융 거래소에서 새로운 디지털 자산이 상장된다.",
+ "투자 은행가가 고객에게 재무 계획을 제안하고 있다.",
+ "금융 회사에서 신용평가 모델을 업데이트하고 있다.",
+ "투자자들이 새로운 ICO에 참여하려고 하고 있다.",
+ "경제 전문가가 금리 인상에 대한 예측을 내리고 있다.",
+corpus_embeddings = embedder.encode(corpus, convert_to_tensor=True)
+# Financial domain queries
+queries = [
+ "한 투자자가 비트코인을 매수한다.",
+ "은행에서 대출을 상환하는 고객이 있다.",
+ "금융 분야에서 새로운 기술 동향을 조사하고 있다."
+# Find the closest 5 sentences of the corpus for each query sentence based on cosine similarity
+top_k = 5
+for query in queries:
+ query_embedding = embedder.encode(query, convert_to_tensor=True)
+ cos_scores = util.pytorch_cos_sim(query_embedding, corpus_embeddings)[0]
+ cos_scores = cos_scores.cpu()
+ # We use np.argpartition, to only partially sort the top_k results
+ top_results = np.argpartition(-cos_scores, range(top_k))[0:top_k]
+ print("\n\n======================\n\n")
+ print("Query:", query)
+ print("\nTop 5 most similar sentences in the financial corpus:")
+ for idx in top_results[0:top_k]:
+ print(corpus[idx].strip(), "(Score: %.4f)" % (cos_scores[idx]))
+Query: 한 투자자가 비트코인을 매수한다.
+Top 5 most similar sentences in the financial corpus:
+주식 시장에서 한 투자자가 주식을 매수한다. (Score: 0.7579)
+투자자들이 새로운 ICO에 참여하려고 하고 있다. (Score: 0.4809)
+금융 거래소에서 새로운 디지털 자산이 상장된다. (Score: 0.4669)
+금융 전문가가 새로운 투자 전략을 개발하고 있다. (Score: 0.3499)
+투자 은행가가 고객에게 재무 계획을 제안하고 있다. (Score: 0.3279)
+Query: 은행에서 대출을 상환하는 고객이 있다.
+Top 5 most similar sentences in the financial corpus:
+은행에서 예금을 만기로 인출하는 고객이 있다. (Score: 0.7762)
+금융 회사에서 신용평가 모델을 업데이트하고 있다. (Score: 0.3431)
+투자 은행가가 고객에게 재무 계획을 제안하고 있다. (Score: 0.3422)
+주식 시장에서 한 투자자가 주식을 매수한다. (Score: 0.2330)
+금융 거래소에서 새로운 디지털 자산이 상장된다. (Score: 0.1982)
+Query: 금융 분야에서 새로운 기술 동향을 조사하고 있다.
+Top 5 most similar sentences in the financial corpus:
+금융 거래소에서 새로운 디지털 자산이 상장된다. (Score: 0.5661)
+금융 회사에서 신용평가 모델을 업데이트하고 있다. (Score: 0.5184)
+금융 전문가가 새로운 투자 전략을 개발하고 있다. (Score: 0.5122)
+투자자들이 새로운 ICO에 참여하려고 하고 있다. (Score: 0.4111)
+투자 은행가가 고객에게 재무 계획을 제안하고 있다. (Score: 0.3708)
+## Training
+직접 모델을 파인튜닝하려면 [`kor-nlu-datasets`](https://github.com/kakaobrain/kor-nlu-datasets) 저장소를 clone 하고 `training_*.py` 스크립트를 실행시키면 됩니다.
+`train.sh` 파일에서 학습 예시를 확인할 수 있습니다.
+git clone https://github.com/upskyy/kf-deberta-multitask.git
+cd kf-deberta-multitask
+pip install -r requirements.txt
+git clone https://github.com/kakaobrain/kor-nlu-datasets.git
+python training_multi_task.py --model_name_or_path kakaobank/kf-deberta-base
+## Evaluation
+KorSTS Benchmark를 평가하는 방법입니다.
+git clone https://github.com/upskyy/kf-deberta-multitask.git
+cd kf-deberta-multitask
+pip install -r requirements.txt
+git clone https://github.com/kakaobrain/kor-nlu-datasets.git
+python bin/benchmark.py --model_name_or_path upskyy/kf-deberta-multitask
+## Export ONNX
+`requirements.txt` 설치 후 `bin` 디렉토리에서 `export_onnx.py` 스크립트를 실행시키면 됩니다.
+git clone https://github.com/upskyy/kf-deberta-multitask.git
+cd kf-deberta-multitask
+pip install -r requirements.txt
+python bin/export_onnx.py
+## Acknowledgements
+- [kakaobank/kf-deberta-base](https://huggingface.co/kakaobank/kf-deberta-base) for pretrained model
+- [jhgan00/ko-sentence-transformers](https://github.com/jhgan00/ko-sentence-transformers) for original codebase
+- [kor-nlu-datasets](https://github.com/kakaobrain/kor-nlu-datasets) for training data
+## Citation
+ title = {KF-DeBERTa: Financial Domain-specific Pre-trained Language Model},
+ author = {Eunkwang Jeon, Jungdae Kim, Minsang Song, and Joohyun Ryu},
+ booktitle = {Proceedings of the 35th Annual Conference on Human and Cognitive Language Technology},
+ month = {oct},
+ year = {2023},
+ publisher = {Korean Institute of Information Scientists and Engineers},
+ url = {http://www.hclt.kr/symp/?lnb=conference},
+ pages = {143--148},
+ title={KorNLI and KorSTS: New Benchmark Datasets for Korean Natural Language Understanding},
+ author={Ham, Jiyeon and Choe, Yo Joong and Park, Kyubyong and Choi, Ilji and Soh, Hyungjoon},
+ journal={arXiv preprint arXiv:2004.03289},
+ year={2020}
diff --git a/bin/benchmark.py b/bin/benchmark.py
new file mode 100644
index 0000000..efb7697
--- /dev/null
+++ b/bin/benchmark.py
@@ -0,0 +1,33 @@
+import argparse
+import csv
+import logging
+import os
+from sentence_transformers import InputExample, LoggingHandler, SentenceTransformer
+from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
+ format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
+if __name__ == "__main__":
+ # Score a SentenceTransformer model on the sts-test.tsv file using EmbeddingSimilarityEvaluator.
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--sts_dataset_path", type=str, default="kor-nlu-datasets/KorSTS")
+ # No default is provided: argparse rejects the invocation when this option is omitted.
+ parser.add_argument("--model_name_or_path", type=str, required=True)
+ args = parser.parse_args()
+ # Read the test split (sts-test.tsv) from the dataset directory as evaluation samples
+ test_samples = []
+ test_file = os.path.join(args.sts_dataset_path, "sts-test.tsv")
+ with open(test_file, "rt", encoding="utf8") as fIn:
+ reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
+ for row in reader:
+ score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 1
+ test_samples.append(InputExample(texts=[row["sentence1"], row["sentence2"]], label=score))
+ test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name="sts-test")
+ model = SentenceTransformer(args.model_name_or_path)
+ # Invoke the evaluator on the loaded model; its return value is not used here.
+ test_evaluator(model)
diff --git a/bin/export_onnx.py b/bin/export_onnx.py
new file mode 100644
index 0000000..01d8b78
--- /dev/null
+++ b/bin/export_onnx.py
@@ -0,0 +1,13 @@
+import os
+from pathlib import Path
+from transformers.convert_graph_to_onnx import convert
+if __name__ == "__main__":
+    # Export the "upskyy/kf-deberta-multitask" checkpoint to models/kf-deberta-multitask.onnx.
+    output_dir = "models"
+    # A single idempotent call replaces the previous exists()-then-makedirs(exist_ok=False)
+    # pair, which could raise FileExistsError if the directory appeared between the two steps.
+    os.makedirs(output_dir, exist_ok=True)
+
+    output_fpath = os.path.join(output_dir, "kf-deberta-multitask.onnx")
+    # convert() expects a pathlib.Path for `output`, hence the explicit wrapping.
+    convert(framework="pt", model="upskyy/kf-deberta-multitask", output=Path(output_fpath), opset=15)
diff --git a/bin/train.sh b/bin/train.sh
new file mode 100755
index 0000000..b13ad85
--- /dev/null
+++ b/bin/train.sh
@@ -0,0 +1,23 @@
+# To start training, you need to download the KorNLUDatasets first.
+# git clone https://github.com/kakaobrain/kor-nlu-datasets.git
+# train on STS dataset only
+# python training_sts.py --model_name_or_path klue/bert-base
+# python training_sts.py --model_name_or_path klue/roberta-base
+# python training_sts.py --model_name_or_path klue/roberta-small
+# python training_sts.py --model_name_or_path klue/roberta-large
+python training_sts.py --model_name_or_path kakaobank/kf-deberta-base
+# train on both NLI and STS dataset (multi-task)
+# python training_multi_task.py --model_name_or_path klue/bert-base
+# python training_multi_task.py --model_name_or_path klue/roberta-base
+# python training_multi_task.py --model_name_or_path klue/roberta-small
+# python training_multi_task.py --model_name_or_path klue/roberta-large
+python training_multi_task.py --model_name_or_path kakaobank/kf-deberta-base
+# train on NLI dataset only
+# python training_nli.py --model_name_or_path klue/bert-base
+# python training_nli.py --model_name_or_path klue/roberta-base
+# python training_nli.py --model_name_or_path klue/roberta-small
+# python training_nli.py --model_name_or_path klue/roberta-large
+python training_nli.py --model_name_or_path kakaobank/kf-deberta-base
\ No newline at end of file
diff --git a/data_util.py b/data_util.py
new file mode 100644
index 0000000..adb7c37
--- /dev/null
+++ b/data_util.py
@@ -0,0 +1,54 @@
+import csv
+import random
+from sentence_transformers.readers import InputExample
+def load_kor_sts_samples(filename):
+    """Load sentence pairs from a tab-separated STS file as InputExample objects.
+
+    Args:
+        filename: Path to a TSV file with "sentence1", "sentence2" and "score" columns.
+
+    Returns:
+        A list with one InputExample per row; the label is the row's score divided by 5.0.
+    """
+
+    def to_example(record):
+        # Scale the score first (0 ... 5 -> 0 ... 1), then build the pair.
+        normalized = float(record["score"]) / 5.0
+        return InputExample(texts=[record["sentence1"], record["sentence2"]], label=normalized)
+
+    with open(filename, "rt", encoding="utf8") as tsv_file:
+        records = csv.DictReader(tsv_file, delimiter="\t", quoting=csv.QUOTE_NONE)
+        # Materialize inside the `with` block so the file is still open while rows are read.
+        return list(map(to_example, records))
+def load_kor_nli_samples(filename):
+    """Build sentence triplets from a tab-separated NLI file.
+
+    Every row contributes its pair in both directions under the row's
+    "gold_label". For each sentence that has at least one entailment and one
+    contradiction partner, two InputExample triplets are produced:
+    [sentence, entailment, contradiction] and [entailment, sentence, contradiction],
+    with the partners drawn via random.choice.
+
+    Args:
+        filename: Path to a TSV file with "sentence1", "sentence2" and "gold_label" columns.
+
+    Returns:
+        A list of InputExample objects, each holding three texts and no label.
+    """
+    data = {}
+
+    def add_to_samples(sent1, sent2, label):
+        # Lazily create the per-sentence buckets; sets de-duplicate repeated partners.
+        if sent1 not in data:
+            data[sent1] = {"contradiction": set(), "entailment": set(), "neutral": set()}
+        data[sent1][label].add(sent2)
+
+    with open(filename, "r", encoding="utf-8") as fIn:
+        reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
+        for row in reader:
+            sent1 = row["sentence1"].strip()
+            sent2 = row["sentence2"].strip()
+            add_to_samples(sent1, sent2, row["gold_label"])
+            add_to_samples(sent2, sent1, row["gold_label"])  # Also add the opposite
+
+    samples = []
+    for sent, others in data.items():
+        if others["entailment"] and others["contradiction"]:
+            # Sort before sampling: the iteration order of a set of str varies between
+            # interpreter runs (string hash randomization), so list(set) would make
+            # random.choice pick different sentences even with a fixed random seed.
+            entailments = sorted(others["entailment"])
+            contradictions = sorted(others["contradiction"])
+            samples.append(
+                InputExample(texts=[sent, random.choice(entailments), random.choice(contradictions)])
+            )
+            samples.append(
+                InputExample(texts=[random.choice(entailments), sent, random.choice(contradictions)])
+            )
+    return samples
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..14da493
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,2 @@
\ No newline at end of file
diff --git a/training_multi_task.py b/training_multi_task.py
new file mode 100644
index 0000000..8a74610
--- /dev/null
+++ b/training_multi_task.py
@@ -0,0 +1,114 @@
+import argparse
+import glob
+import logging
+import math
+import os
+import random
+from datetime import datetime
+import numpy as np
+import torch
+from sentence_transformers import LoggingHandler, SentenceTransformer, datasets, losses, models
+from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
+from torch.utils.data import DataLoader
+from data_util import load_kor_nli_samples, load_kor_sts_samples
+# Configure logger
+ format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_name_or_path", type=str)
+ parser.add_argument("--max_seq_length", type=int, default=256)
+ parser.add_argument("--nli_batch_size", type=int, default=64)
+ parser.add_argument("--sts_batch_size", type=int, default=8)
+ parser.add_argument("--num_epochs", type=int, default=10)
+ parser.add_argument("--output_dir", type=str, default="output")
+ parser.add_argument("--output_prefix", type=str, default="kor_multi_")
+ parser.add_argument("--seed", type=int, default=42)
+ args = parser.parse_args()
+ # Fix random seed
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed(args.seed)
+ # Read the dataset
+ model_save_path = os.path.join(
+ args.output_dir,
+ args.output_prefix
+ + args.model_name_or_path.replace("/", "-")
+ + "-"
+ + datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
+ )
+ # Define SentenceTransformer model
+ word_embedding_model = models.Transformer(args.model_name_or_path, max_seq_length=args.max_seq_length)
+ pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode="mean")
+ model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+ # Read the dataset
+ nli_dataset_path = "kor-nlu-datasets/KorNLI"
+ sts_dataset_path = "kor-nlu-datasets/KorSTS"
+ logging.info("Read KorNLI train/KorSTS dev dataset")
+ # Read NLI training dataset
+ nli_train_files = glob.glob(os.path.join(nli_dataset_path, "*train.ko.tsv"))
+ dev_file = os.path.join(sts_dataset_path, "sts-dev.tsv")
+ nli_train_samples = []
+ for nli_train_file in nli_train_files:
+ nli_train_samples += load_kor_nli_samples(nli_train_file)
+ nli_train_dataloader = datasets.NoDuplicatesDataLoader(nli_train_samples, batch_size=args.nli_batch_size)
+ nli_train_loss = losses.MultipleNegativesRankingLoss(model)
+ # Read STS training dataset
+ sts_dataset_path = "kor-nlu-datasets/KorSTS"
+ sts_train_file = os.path.join(sts_dataset_path, "sts-train.tsv")
+ sts_train_samples = load_kor_sts_samples(sts_train_file)
+ sts_train_dataloader = DataLoader(sts_train_samples, shuffle=True, batch_size=args.sts_batch_size)
+ sts_train_loss = losses.CosineSimilarityLoss(model=model)
+ # Read STS dev dataset
+ dev_samples = load_kor_sts_samples(dev_file)
+ dev_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
+ dev_samples, batch_size=args.sts_batch_size, name="sts-dev"
+ )
+ # In multi-task training setting,
+ print("length of nli data loader:", len(nli_train_dataloader))
+ print("length of sts data loader:", len(sts_train_dataloader))
+ steps_per_epoch = min(len(nli_train_dataloader), len(sts_train_dataloader))
+ # Configure the training.
+ warmup_steps = math.ceil(steps_per_epoch * args.num_epochs * 0.1) # 10% of train data for warm-up
+ logging.info("Warmup-steps: {}".format(warmup_steps))
+ # Train the model
+ train_objectives = [(nli_train_dataloader, nli_train_loss), (sts_train_dataloader, sts_train_loss)]
+ model.fit(
+ train_objectives=train_objectives,
+ evaluator=dev_evaluator,
+ epochs=args.num_epochs,
+ optimizer_params={"lr": 2e-5},
+ evaluation_steps=1000,
+ warmup_steps=warmup_steps,
+ output_path=model_save_path,
+ )
+ # Load the stored model and evaluate its performance on STS benchmark dataset
+ model = SentenceTransformer(model_save_path)
+ logging.info("Read KorSTS benchmark test dataset")
+ test_file = os.path.join(sts_dataset_path, "sts-test.tsv")
+ test_samples = load_kor_sts_samples(test_file)
+ test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name="sts-test")
+ test_evaluator(model, output_path=model_save_path)
diff --git a/training_nli.py b/training_nli.py
new file mode 100644
index 0000000..6ffd9e1
--- /dev/null
+++ b/training_nli.py
@@ -0,0 +1,99 @@
+import argparse
+import glob
+import logging
+import math
+import os
+import random
+from datetime import datetime
+import numpy as np
+import torch
+from sentence_transformers import LoggingHandler, SentenceTransformer, datasets, losses, models
+from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
+from data_util import load_kor_nli_samples, load_kor_sts_samples
+# Configure logger
+ format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_name_or_path", type=str)
+ parser.add_argument("--max_seq_length", type=int, default=128)
+ parser.add_argument("--batch_size", type=int, default=64)
+ parser.add_argument("--num_epochs", type=int, default=1)
+ parser.add_argument("--output_dir", type=str, default="output")
+ parser.add_argument("--output_prefix", type=str, default="kor_nli_")
+ parser.add_argument("--seed", type=int, default=777)
+ args = parser.parse_args()
+ # Fix random seed
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed(args.seed)
+ # Read the dataset
+ model_save_path = os.path.join(
+ args.output_dir,
+ args.output_prefix
+ + args.model_name_or_path.replace("/", "-")
+ + "-"
+ + datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
+ )
+ # Define SentenceTransformer model
+ word_embedding_model = models.Transformer(args.model_name_or_path, max_seq_length=args.max_seq_length)
+ pooling_model = models.Pooling(
+ word_embedding_model.get_word_embedding_dimension(),
+ pooling_mode_mean_tokens=True,
+ pooling_mode_cls_token=False,
+ pooling_mode_max_tokens=False,
+ )
+ model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+ # Read the dataset
+ nli_dataset_path = "kor-nlu-datasets/KorNLI"
+ sts_dataset_path = "kor-nlu-datasets/KorSTS"
+ logging.info("Read KorNLI train/KorSTS dev dataset")
+ train_files = glob.glob(os.path.join(nli_dataset_path, "*train.ko.tsv"))
+ dev_file = os.path.join(sts_dataset_path, "sts-dev.tsv")
+ train_samples = []
+ for train_file in train_files:
+ train_samples += load_kor_nli_samples(train_file)
+ train_dataloader = datasets.NoDuplicatesDataLoader(train_samples, batch_size=args.batch_size)
+ dev_samples = load_kor_sts_samples(dev_file)
+ dev_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
+ dev_samples, batch_size=args.batch_size, name="sts-dev"
+ )
+ train_loss = losses.MultipleNegativesRankingLoss(model)
+ # Configure the training.
+ warmup_steps = math.ceil(len(train_dataloader) * args.num_epochs * 0.1) # 10% of train data for warm-up
+ logging.info("Warmup-steps: {}".format(warmup_steps))
+ # Train the model
+ model.fit(
+ train_objectives=[(train_dataloader, train_loss)],
+ evaluator=dev_evaluator,
+ epochs=args.num_epochs,
+ optimizer_params={"lr": 2e-5},
+ evaluation_steps=1000,
+ warmup_steps=warmup_steps,
+ output_path=model_save_path,
+ )
+ # Load the stored model and evaluate its performance on STS benchmark dataset
+ model = SentenceTransformer(model_save_path)
+ logging.info("Read KorSTS benchmark test dataset")
+ test_file = os.path.join(sts_dataset_path, "sts-test.tsv")
+ test_samples = load_kor_sts_samples(test_file)
+ test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name="sts-test")
+ test_evaluator(model, output_path=model_save_path)
diff --git a/training_sts.py b/training_sts.py
new file mode 100644
index 0000000..62e657d
--- /dev/null
+++ b/training_sts.py
@@ -0,0 +1,100 @@
+"""
+This example trains KoBERT for the STS benchmark from scratch.
+It generates sentence embeddings that can be compared using cosine-similarity to measure the similarity.
+
+Usage:
+python training_sts.py --model_name_or_path klue/bert-base
+"""
+import argparse
+import logging
+import math
+import os
+import random
+from datetime import datetime
+import numpy as np
+import torch
+from sentence_transformers import LoggingHandler, SentenceTransformer, losses, models
+from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
+from torch.utils.data import DataLoader
+from data_util import load_kor_sts_samples
+# Configure logger
+ format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_name_or_path", type=str)
+ parser.add_argument("--max_seq_length", type=int, default=128)
+ parser.add_argument("--batch_size", type=int, default=8)
+ parser.add_argument("--num_epochs", type=int, default=5)
+ parser.add_argument("--output_dir", type=str, default="output")
+ parser.add_argument("--output_prefix", type=str, default="kor_sts_")
+ parser.add_argument("--seed", type=int, default=777)
+ args = parser.parse_args()
+ # Fix random seed
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed(args.seed)
+ # Read the dataset
+ model_save_path = os.path.join(
+ args.output_dir,
+ args.output_prefix
+ + args.model_name_or_path.replace("/", "-")
+ + "-"
+ + datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
+ )
+ # Define SentenceTransformer model
+ word_embedding_model = models.Transformer(args.model_name_or_path, max_seq_length=args.max_seq_length)
+ pooling_model = models.Pooling(
+ word_embedding_model.get_word_embedding_dimension(),
+ pooling_mode_mean_tokens=True,
+ pooling_mode_cls_token=False,
+ pooling_mode_max_tokens=False,
+ )
+ model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+ # Read the dataset
+ logging.info("Read KorSTS train/dev dataset")
+ sts_dataset_path = "kor-nlu-datasets/KorSTS"
+ train_file, dev_file = os.path.join(sts_dataset_path, "sts-train.tsv"), os.path.join(
+ sts_dataset_path, "sts-dev.tsv"
+ )
+ train_samples, dev_samples = load_kor_sts_samples(train_file), load_kor_sts_samples(dev_file)
+ train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=args.batch_size)
+ dev_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
+ dev_samples, batch_size=args.batch_size, name="sts-dev"
+ )
+ train_loss = losses.CosineSimilarityLoss(model=model)
+ # Configure the training.
+ warmup_steps = math.ceil(len(train_dataloader) * args.num_epochs * 0.1) # 10% of train data for warm-up
+ logging.info("Warmup-steps: {}".format(warmup_steps))
+ # Train the model
+ model.fit(
+ train_objectives=[(train_dataloader, train_loss)],
+ evaluator=dev_evaluator,
+ epochs=args.num_epochs,
+ optimizer_params={"lr": 2e-5},
+ evaluation_steps=1000,
+ warmup_steps=warmup_steps,
+ output_path=model_save_path,
+ )
+ # Load the stored model and evaluate its performance on STS benchmark dataset
+ model = SentenceTransformer(model_save_path)
+ logging.info("Read KorSTS benchmark test dataset")
+ test_file = os.path.join(sts_dataset_path, "sts-test.tsv")
+ test_samples = load_kor_sts_samples(test_file)
+ test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name="sts-test")
+ test_evaluator(model, output_path=model_save_path)