diff --git a/Chat_gpt_api.ipynb b/Chat_gpt_api.ipynb
new file mode 100644
index 0000000..fd5080b
--- /dev/null
+++ b/Chat_gpt_api.ipynb
@@ -0,0 +1,53 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "authorship_tag": "ABX9TyMzfOltuiIfV+7/NOvFLSPG",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "_-7e9e-qjOFi"
+ },
+ "outputs": [],
+ "source": [
+ "from openai import OpenAI\n",
+ "\n",
+ "def get_completion(prompt, model=\"gpt-3.5-turbo\"):\n",
+ " client = OpenAI(api_key='API_KEY') #gpt api key 받아와서 입력\n",
+ " completion = client.chat.completions.create(\n",
+ " model=model,\n",
+ " messages=[{\"role\": \"user\", \"content\":prompt},])\n",
+ " response = completion.choices[0].message.content\n",
+ " return response\n",
+ "\n",
+ "prompt = input(\"검색하실 판례를 입력해주세요:\")\n",
+ "response = get_completion(prompt)\n",
+ "print(response)"
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/README.md b/README.md
index afc4eac..d763dd3 100644
--- a/README.md
+++ b/README.md
@@ -9,4 +9,10 @@
|ML Engineering|방가윤, 허정원|
## 사용 모델
-[데이터](https://blog.lbox.kr/lbox-open)
\ No newline at end of file
+[데이터](https://blog.lbox.kr/lbox-open)
+MT5ForConditionalGeneration
+
+## 프로젝트 개요
+판례검색 Datasets으로 ibox를 이용하여 MT5ForConditionalGeneration 모델을 이용
+학습 및 테스트 진행
+모델과 chatgpt api 와 연결하여 사용
diff --git a/base_model.ipynb b/base_model.ipynb
new file mode 100644
index 0000000..666c88c
--- /dev/null
+++ b/base_model.ipynb
@@ -0,0 +1,1918 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU",
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "a61083da9c724a3ab6e70845e57e4bcf": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_24b5a18022fd4670885fa1e6a05f667e",
+ "IPY_MODEL_e9dafcde4c7a4cee90a7381d50c0b1f1",
+ "IPY_MODEL_d9a20d8c643f4c3a886089441d30567d"
+ ],
+ "layout": "IPY_MODEL_60273a3f8c8242e083fa7588fada897d"
+ }
+ },
+ "24b5a18022fd4670885fa1e6a05f667e": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_d0a6e68ebefd47cb85ef55959a0e4dd3",
+ "placeholder": "",
+ "style": "IPY_MODEL_78b5bf4f12374ad69bbadc336d66574a",
+ "value": "Validation sanity check: 100%"
+ }
+ },
+ "e9dafcde4c7a4cee90a7381d50c0b1f1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_ca94eb51bd9c44a6a0b1608fe93b8f9f",
+ "max": 2,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_919e939889b24821b64e272e8fb66e88",
+ "value": 2
+ }
+ },
+ "d9a20d8c643f4c3a886089441d30567d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_c5bd7eb7067e4e89a5974518ce55b78b",
+ "placeholder": "",
+ "style": "IPY_MODEL_944ae00dc9d84f6ea6ad09be593c810c",
+ "value": " 2/2 [00:00<00:00, 2.38it/s]"
+ }
+ },
+ "60273a3f8c8242e083fa7588fada897d": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": "inline-flex",
+ "flex": null,
+ "flex_flow": "row wrap",
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": "hidden",
+ "width": "100%"
+ }
+ },
+ "d0a6e68ebefd47cb85ef55959a0e4dd3": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "78b5bf4f12374ad69bbadc336d66574a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "ca94eb51bd9c44a6a0b1608fe93b8f9f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": "2",
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "919e939889b24821b64e272e8fb66e88": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "c5bd7eb7067e4e89a5974518ce55b78b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "944ae00dc9d84f6ea6ad09be593c810c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "ca38c82573d947c894dc5511b0d75874": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_c2dac06a8dfc4228b3d3b698c80e2378",
+ "IPY_MODEL_050f7cdd27d0449e92b8a83b37c60d4b",
+ "IPY_MODEL_abfaf40a12b348498b81eec95c8745ab"
+ ],
+ "layout": "IPY_MODEL_0317abbf1b184cd9b2e83903ff1709ef"
+ }
+ },
+ "c2dac06a8dfc4228b3d3b698c80e2378": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_d528ccbea8a3486d963f0e443e1e6a9f",
+ "placeholder": "",
+ "style": "IPY_MODEL_1f6fadb0d5e046cba9b8e1af2c657b0c",
+ "value": "Epoch 0: 100%"
+ }
+ },
+ "050f7cdd27d0449e92b8a83b37c60d4b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_13aa11ba42e94ffcbaeb833497d9c578",
+ "max": 1063,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_bc59bdc73a704ee888577f014ac64aac",
+ "value": 1063
+ }
+ },
+ "abfaf40a12b348498b81eec95c8745ab": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_36b234cb85d54c48807666c35e4c9c6b",
+ "placeholder": "",
+ "style": "IPY_MODEL_5fc2ae69dd4148aca6a86b4acc85adcd",
+ "value": " 1063/1063 [09:04<00:00, 1.95it/s, loss=1.28, v_num=1]"
+ }
+ },
+ "0317abbf1b184cd9b2e83903ff1709ef": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": "inline-flex",
+ "flex": null,
+ "flex_flow": "row wrap",
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": "100%"
+ }
+ },
+ "d528ccbea8a3486d963f0e443e1e6a9f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "1f6fadb0d5e046cba9b8e1af2c657b0c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "13aa11ba42e94ffcbaeb833497d9c578": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": "2",
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "bc59bdc73a704ee888577f014ac64aac": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "36b234cb85d54c48807666c35e4c9c6b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "5fc2ae69dd4148aca6a86b4acc85adcd": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "3c7e4c2215dd47fb8d44289d81d46380": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_b4e60cb8e1c74c7fa5ce959ac66f854b",
+ "IPY_MODEL_7567dc9ff7e74703b9f87bc21b002345",
+ "IPY_MODEL_64fb3d1e5f3c4447b54e1b147196625d"
+ ],
+ "layout": "IPY_MODEL_19c27c9a3905406b81ced45a093cc53c"
+ }
+ },
+ "b4e60cb8e1c74c7fa5ce959ac66f854b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_d76767df8285449aa115018c2cdac272",
+ "placeholder": "",
+ "style": "IPY_MODEL_550b6efe5bb541eb9cf3245f62befcb4",
+ "value": "Validating: 100%"
+ }
+ },
+ "7567dc9ff7e74703b9f87bc21b002345": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_7a3eab87e135454d96232271c8b4b8e1",
+ "max": 63,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_3e4262f1135e4a27832fb8c76e286593",
+ "value": 63
+ }
+ },
+ "64fb3d1e5f3c4447b54e1b147196625d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_026d6f14244b4643acc8f8ab87af1c71",
+ "placeholder": "",
+ "style": "IPY_MODEL_1ad9b585d5934abc9a0930ef94799f63",
+ "value": " 63/63 [01:19<00:00, 1.27s/it]"
+ }
+ },
+ "19c27c9a3905406b81ced45a093cc53c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": "inline-flex",
+ "flex": null,
+ "flex_flow": "row wrap",
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": "hidden",
+ "width": "100%"
+ }
+ },
+ "d76767df8285449aa115018c2cdac272": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "550b6efe5bb541eb9cf3245f62befcb4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "7a3eab87e135454d96232271c8b4b8e1": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": "2",
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "3e4262f1135e4a27832fb8c76e286593": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "026d6f14244b4643acc8f8ab87af1c71": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "1ad9b585d5934abc9a0930ef94799f63": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ }
+ }
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "#1. Install\n",
+ "warning은 무시해도 됩니다\n",
+ "임의로 버전을 고정한거라 뜨는 메세지 같아요"
+ ],
+ "metadata": {
+ "id": "Z6v4imMIm5gs"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install --upgrade setuptools\n",
+ "!pip install fastapi kaleido python-multipart uvicorn\n",
+ "#!pip install jedi\n",
+ "!pip install transformers==4.16.2\n",
+ "!pip install sentencepiece\n",
+ "!pip install datasets\n",
+ "!pip install rouge_score\n",
+ "!pip install pytorch_lightning==1.5.10"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ },
+ "id": "USqNXkEMjdE2",
+ "outputId": "cd9704e1-b993-46bd-a8cd-8aaddcb7b607"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (59.5.0)\n",
+ "Collecting setuptools\n",
+ " Using cached setuptools-68.2.2-py3-none-any.whl (807 kB)\n",
+ "Installing collected packages: setuptools\n",
+ " Attempting uninstall: setuptools\n",
+ " Found existing installation: setuptools 59.5.0\n",
+ " Uninstalling setuptools-59.5.0:\n",
+ " Successfully uninstalled setuptools-59.5.0\n",
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "ipython 7.34.0 requires jedi>=0.16, which is not installed.\n",
+ "pytorch-lightning 1.5.10 requires setuptools==59.5.0, but you have setuptools 68.2.2 which is incompatible.\u001b[0m\u001b[31m\n",
+ "\u001b[0mSuccessfully installed setuptools-68.2.2\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "application/vnd.colab-display-data+json": {
+ "pip_warning": {
+ "packages": [
+ "_distutils_hack",
+ "pkg_resources",
+ "setuptools"
+ ]
+ }
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (0.104.1)\n",
+ "Requirement already satisfied: kaleido in /usr/local/lib/python3.10/dist-packages (0.2.1)\n",
+ "Requirement already satisfied: python-multipart in /usr/local/lib/python3.10/dist-packages (0.0.6)\n",
+ "Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (0.24.0.post1)\n",
+ "Requirement already satisfied: anyio<4.0.0,>=3.7.1 in /usr/local/lib/python3.10/dist-packages (from fastapi) (3.7.1)\n",
+ "Requirement already satisfied: pydantic!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,!=2.1.0,<3.0.0,>=1.7.4 in /usr/local/lib/python3.10/dist-packages (from fastapi) (1.10.13)\n",
+ "Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi) (0.27.0)\n",
+ "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from fastapi) (4.8.0)\n",
+ "Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from uvicorn) (8.1.7)\n",
+ "Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn) (0.14.0)\n",
+ "Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.10/dist-packages (from anyio<4.0.0,>=3.7.1->fastapi) (3.4)\n",
+ "Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<4.0.0,>=3.7.1->fastapi) (1.3.0)\n",
+ "Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<4.0.0,>=3.7.1->fastapi) (1.1.3)\n",
+ "Requirement already satisfied: transformers==4.16.2 in /usr/local/lib/python3.10/dist-packages (4.16.2)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers==4.16.2) (3.13.1)\n",
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.1.0 in /usr/local/lib/python3.10/dist-packages (from transformers==4.16.2) (0.19.2)\n",
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.16.2) (1.23.5)\n",
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers==4.16.2) (23.2)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.16.2) (6.0.1)\n",
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.16.2) (2023.6.3)\n",
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers==4.16.2) (2.31.0)\n",
+ "Requirement already satisfied: sacremoses in /usr/local/lib/python3.10/dist-packages (from transformers==4.16.2) (0.1.1)\n",
+ "Requirement already satisfied: tokenizers!=0.11.3,>=0.10.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.16.2) (0.15.0)\n",
+ "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers==4.16.2) (4.66.1)\n",
+ "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers==4.16.2) (2023.6.0)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers==4.16.2) (4.8.0)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.16.2) (3.3.2)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.16.2) (3.4)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.16.2) (2.0.7)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.16.2) (2023.7.22)\n",
+ "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from sacremoses->transformers==4.16.2) (8.1.7)\n",
+ "Requirement already satisfied: joblib in /usr/local/lib/python3.10/dist-packages (from sacremoses->transformers==4.16.2) (1.3.2)\n",
+ "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (0.1.99)\n",
+ "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.14.7)\n",
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.23.5)\n",
+ "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (9.0.0)\n",
+ "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets) (0.5)\n",
+ "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.7)\n",
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n",
+ "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.31.0)\n",
+ "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.1)\n",
+ "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n",
+ "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.15)\n",
+ "Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n",
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.8.6)\n",
+ "Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.19.2)\n",
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (23.2)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n",
+ "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (3.3.2)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n",
+ "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.2)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.0)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (3.13.1)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (4.8.0)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.4)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2.0.7)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2023.7.22)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.3.post1)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n",
+ "Requirement already satisfied: rouge_score in /usr/local/lib/python3.10/dist-packages (0.1.2)\n",
+ "Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from rouge_score) (1.4.0)\n",
+ "Requirement already satisfied: nltk in /usr/local/lib/python3.10/dist-packages (from rouge_score) (3.8.1)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from rouge_score) (1.23.5)\n",
+ "Requirement already satisfied: six>=1.14.0 in /usr/local/lib/python3.10/dist-packages (from rouge_score) (1.16.0)\n",
+ "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (8.1.7)\n",
+ "Requirement already satisfied: joblib in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (1.3.2)\n",
+ "Requirement already satisfied: regex>=2021.8.3 in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (2023.6.3)\n",
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (4.66.1)\n",
+ "Requirement already satisfied: pytorch_lightning==1.5.10 in /usr/local/lib/python3.10/dist-packages (1.5.10)\n",
+ "Requirement already satisfied: numpy>=1.17.2 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (1.23.5)\n",
+ "Requirement already satisfied: torch>=1.7.* in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (2.1.0+cu118)\n",
+ "Requirement already satisfied: future>=0.17.1 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (0.18.3)\n",
+ "Requirement already satisfied: tqdm>=4.41.0 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (4.66.1)\n",
+ "Requirement already satisfied: PyYAML>=5.1 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (6.0.1)\n",
+ "Requirement already satisfied: fsspec[http]!=2021.06.0,>=2021.05.0 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (2023.6.0)\n",
+ "Requirement already satisfied: tensorboard>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (2.14.1)\n",
+ "Requirement already satisfied: torchmetrics>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (1.2.0)\n",
+ "Requirement already satisfied: pyDeprecate==0.3.1 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (0.3.1)\n",
+ "Requirement already satisfied: packaging>=17.0 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (23.2)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (4.8.0)\n",
+ "Collecting setuptools==59.5.0 (from pytorch_lightning==1.5.10)\n",
+ " Using cached setuptools-59.5.0-py3-none-any.whl (952 kB)\n",
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (2.31.0)\n",
+ "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (3.8.6)\n",
+ "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.2.0->pytorch_lightning==1.5.10) (1.4.0)\n",
+ "Requirement already satisfied: grpcio>=1.48.2 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.2.0->pytorch_lightning==1.5.10) (1.59.2)\n",
+ "Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.2.0->pytorch_lightning==1.5.10) (2.17.3)\n",
+ "Requirement already satisfied: google-auth-oauthlib<1.1,>=0.5 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.2.0->pytorch_lightning==1.5.10) (1.0.0)\n",
+ "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.2.0->pytorch_lightning==1.5.10) (3.5.1)\n",
+ "Requirement already satisfied: protobuf>=3.19.6 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.2.0->pytorch_lightning==1.5.10) (3.20.3)\n",
+ "Requirement already satisfied: six>1.9 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.2.0->pytorch_lightning==1.5.10) (1.16.0)\n",
+ "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.2.0->pytorch_lightning==1.5.10) (0.7.2)\n",
+ "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.2.0->pytorch_lightning==1.5.10) (3.0.1)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.*->pytorch_lightning==1.5.10) (3.13.1)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.*->pytorch_lightning==1.5.10) (1.12)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.*->pytorch_lightning==1.5.10) (3.2.1)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.*->pytorch_lightning==1.5.10) (3.1.2)\n",
+ "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.*->pytorch_lightning==1.5.10) (2.1.0)\n",
+ "Requirement already satisfied: lightning-utilities>=0.8.0 in /usr/local/lib/python3.10/dist-packages (from torchmetrics>=0.4.1->pytorch_lightning==1.5.10) (0.9.0)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (23.1.0)\n",
+ "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (3.3.2)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (6.0.4)\n",
+ "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (4.0.3)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (1.9.2)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (1.4.0)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (1.3.1)\n",
+ "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch_lightning==1.5.10) (5.3.2)\n",
+ "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch_lightning==1.5.10) (0.3.0)\n",
+ "Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch_lightning==1.5.10) (4.9)\n",
+ "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from google-auth-oauthlib<1.1,>=0.5->tensorboard>=2.2.0->pytorch_lightning==1.5.10) (1.3.1)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (3.4)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (2.0.7)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (2023.7.22)\n",
+ "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=1.0.1->tensorboard>=2.2.0->pytorch_lightning==1.5.10) (2.1.3)\n",
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.7.*->pytorch_lightning==1.5.10) (1.3.0)\n",
+ "Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /usr/local/lib/python3.10/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch_lightning==1.5.10) (0.5.0)\n",
+ "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.10/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<1.1,>=0.5->tensorboard>=2.2.0->pytorch_lightning==1.5.10) (3.2.2)\n",
+ "Installing collected packages: setuptools\n",
+ " Attempting uninstall: setuptools\n",
+ " Found existing installation: setuptools 68.2.2\n",
+ " Uninstalling setuptools-68.2.2:\n",
+ " Successfully uninstalled setuptools-68.2.2\n",
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "ipython 7.34.0 requires jedi>=0.16, which is not installed.\n",
+ "arviz 0.15.1 requires setuptools>=60.0.0, but you have setuptools 59.5.0 which is incompatible.\n",
+ "cvxpy 1.3.2 requires setuptools>65.5.1, but you have setuptools 59.5.0 which is incompatible.\u001b[0m\u001b[31m\n",
+ "\u001b[0mSuccessfully installed setuptools-59.5.0\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "application/vnd.colab-display-data+json": {
+ "pip_warning": {
+ "packages": [
+ "_distutils_hack",
+ "pkg_resources",
+ "setuptools"
+ ]
+ }
+ }
+ },
+ "metadata": {}
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!git clone https://github.com/lbox-kr/lbox_open.git --branch v0.1\n",
+ "%cd lbox_open"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "q3YgO47Hxz9f",
+ "outputId": "8c806186-44b3-4cb6-f61e-8894c6e50d0a"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Cloning into 'lbox_open'...\n",
+ "remote: Enumerating objects: 266, done.\u001b[K\n",
+ "remote: Counting objects: 100% (266/266), done.\u001b[K\n",
+ "remote: Compressing objects: 100% (161/161), done.\u001b[K\n",
+ "remote: Total 266 (delta 128), reused 210 (delta 94), pack-reused 0\u001b[K\n",
+ "Receiving objects: 100% (266/266), 86.64 KiB | 8.66 MiB/s, done.\n",
+ "Resolving deltas: 100% (128/128), done.\n",
+ "Note: switching to '5dd718fe2424d84bc73a78890b357b429c2f9cb0'.\n",
+ "\n",
+ "You are in 'detached HEAD' state. You can look around, make experimental\n",
+ "changes and commit them, and you can discard any commits you make in this\n",
+ "state without impacting any branches by switching back to a branch.\n",
+ "\n",
+ "If you want to create a new branch to retain commits you create, you may\n",
+ "do so (now or later) by using -c with the switch command. Example:\n",
+ "\n",
+ " git switch -c \n",
+ "\n",
+ "Or undo this operation with:\n",
+ "\n",
+ " git switch -\n",
+ "\n",
+ "Turn off this advice by setting config variable advice.detachedHead to false\n",
+ "\n",
+ "/content/lbox_open/lbox_open\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import os\n",
+ "from argparse import Namespace\n",
+ "\n",
+ "import torch\n",
+ "import transformers\n",
+ "import pytorch_lightning as pl\n",
+ "\n",
+ "from lcube.data_module.data_lbox_open import LBoxOpenDataModule\n",
+ "from lcube.model.model_baseline import SeqToSeqBaseline\n",
+ "\n",
+ "device = 'cuda' if torch.cuda.is_available() else \"cpu\""
+ ],
+ "metadata": {
+ "id": "JM_q343LkFbq"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "#2.Parameter 생성"
+ ],
+ "metadata": {
+ "id": "gKEO97qomg3U"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "args = Namespace()\n",
+ "# dataset\n",
+ "args.dataset_card = \"lbox/lbox_open\"\n",
+ "args.task = \"casename_classification\" # comment and uncomment following lines depending on the task\n",
+ "# args.task = \"statute_classification\"\n",
+ "# args.task = \"summarization\"\n",
+ "\n",
+ "\n",
+ "if args.task in [\"casename_classification\", \"statute_classification\"]:\n",
+ " args.input_key = \"facts\"\n",
+ "\n",
+ " # model\n",
+ " args.model_card = \"google/mt5-small\"\n",
+ " args.max_input_len = 512\n",
+ " args.max_target_len = 64\n",
+ "\n",
+ " # train\n",
+ " args.max_epochs = 10\n",
+ " args.learning_rate = 2e-4\n",
+ " args.batch_size = 8\n",
+ " args.batch_size_eval = 2 * args.batch_size\n",
+ " args.accumulate_grad_batches = 1\n",
+ " args.validation_metric = \"exact_match\"\n",
+ "\n",
+ "elif args.task == \"summarization\":\n",
+ " args.input_key = \"precedent\"\n",
+ "\n",
+ " # model\n",
+ " args.model_card = \"google/mt5-small\"\n",
+ " args.max_input_len = 768\n",
+ " args.max_target_len = 512\n",
+ "\n",
+ " # train\n",
+ " args.max_epochs = 10\n",
+ " args.learning_rate = 2e-4\n",
+ " args.batch_size = 1\n",
+ " args.batch_size_eval = 2 * args.batch_size\n",
+ " args.accumulate_grad_batches = 8\n",
+ " args.validation_metric = \"rougeL\"\n",
+ "\n",
+ "else:\n",
+ " raise ValueError\n",
+ "\n",
+ "\n",
+ "args.tokenizer = transformers.MT5TokenizerFast.from_pretrained(args.model_card)\n",
+ "pl.seed_everything(seed=1, workers=False)"
+ ],
+ "metadata": {
+ "id": "CV56eKVyn8fr",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "fa480f20-3a1b-44c2-dd82-438b18a99256"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
+ "INFO:pytorch_lightning.utilities.seed:Global seed set to 1\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "1"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 44
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# 3. Datasets 불러오기"
+ ],
+ "metadata": {
+ "id": "_iC4_4__oPca"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "data_module = LBoxOpenDataModule(\n",
+ " args.dataset_card,\n",
+ " args.task,\n",
+ " args.tokenizer,\n",
+ " args.max_input_len,\n",
+ " args.max_target_len,\n",
+ " args.batch_size,\n",
+ " args.batch_size_eval,\n",
+ ")"
+ ],
+ "metadata": {
+ "id": "uyerxm0woUbl"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# 4. 모델 준비\n",
+ "MT5ForConditionalGeneration"
+ ],
+ "metadata": {
+ "id": "8DpPA9wPogHW"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "backbone = transformers.MT5ForConditionalGeneration.from_pretrained(args.model_card)\n",
+ "model = SeqToSeqBaseline(\n",
+ " args.task,\n",
+ " backbone,\n",
+ " args.tokenizer,\n",
+ " args.learning_rate,\n",
+ " args.max_target_len,\n",
+ " args.validation_metric\n",
+ "\n",
+ ")"
+ ],
+ "metadata": {
+ "id": "P_Pz746Jol4S"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# 5. 학습 준비\n",
+ "epoch 5~10 정도 해야할듯 합니다\n",
+ "\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "_2q1EoVjo1qr"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# PyTorch Lightning을 사용하여 모델을 훈련하고 지정된\n",
+ "# 검증 메트릭에 대한 최고 성능을 갖는 모델을 저장하기 위해\n",
+ "# ModelCheckpoint 콜백을 설정합니다.\n",
+ "callbacks = pl.callbacks.ModelCheckpoint(\n",
+ " #monitor=args.validation_metric, # 검증 메트릭을 모니터링하여 최고 성능 모델을 찾습니다.\n",
+ " monitor='val_loss',\n",
+ " dirpath=f\"./saved/0/{args.task}\", # 최고 성능 모델이 저장될 디렉토리입니다.\n",
+ " save_top_k=1, # 최고의 모델 하나만 저장합니다.\n",
+ " mode=\"max\", # 모니터링 메트릭의 최대 값을 가진 모델을 저장합니다.\n",
+ ")\n",
+ "\n",
+ "# PyTorch Lightning Trainer 인스턴스를 생성합니다.\n",
+ "trainer = pl.Trainer(\n",
+ " max_epochs= 2, # 최대 에폭 수는 1입니다. 더 많은 에폭으로 훈련하려면 10으로 주석 처리된 줄을 사용할 수 있습니다.\n",
+ " #max_epochs = 10,\n",
+ " gpus=torch.cuda.device_count(), # 사용할 GPU 수를 설정합니다.\n",
+ " accumulate_grad_batches=args.accumulate_grad_batches, # 그래디언트 누적 배치 수를 설정합니다.\n",
+ " fast_dev_run=not True, # True로 설정하면 훈련 및 검증을 위해 하나의 배치만 실행됩니다.\n",
+ " callbacks=callbacks, # 사용할 콜백을 설정합니다.\n",
+ "\n",
+ ")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "NRqnxtjOwwRH",
+ "outputId": "f6335b9c-2c8b-468a-9b18-409c618d9dc5"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "INFO:pytorch_lightning.utilities.distributed:GPU available: True, used: True\n",
+ "INFO:pytorch_lightning.utilities.distributed:TPU available: False, using: 0 TPU cores\n",
+ "INFO:pytorch_lightning.utilities.distributed:IPU available: False, using: 0 IPUs\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# 6. 학습\n",
+ "평균 1회당 9분 정도 소요"
+ ],
+ "metadata": {
+ "id": "oTAV9ONrh1P3"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "trainer.fit(model, data_module)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 769,
+ "referenced_widgets": [
+ "a61083da9c724a3ab6e70845e57e4bcf",
+ "24b5a18022fd4670885fa1e6a05f667e",
+ "e9dafcde4c7a4cee90a7381d50c0b1f1",
+ "d9a20d8c643f4c3a886089441d30567d",
+ "60273a3f8c8242e083fa7588fada897d",
+ "d0a6e68ebefd47cb85ef55959a0e4dd3",
+ "78b5bf4f12374ad69bbadc336d66574a",
+ "ca94eb51bd9c44a6a0b1608fe93b8f9f",
+ "919e939889b24821b64e272e8fb66e88",
+ "c5bd7eb7067e4e89a5974518ce55b78b",
+ "944ae00dc9d84f6ea6ad09be593c810c",
+ "ca38c82573d947c894dc5511b0d75874",
+ "c2dac06a8dfc4228b3d3b698c80e2378",
+ "050f7cdd27d0449e92b8a83b37c60d4b",
+ "abfaf40a12b348498b81eec95c8745ab",
+ "0317abbf1b184cd9b2e83903ff1709ef",
+ "d528ccbea8a3486d963f0e443e1e6a9f",
+ "1f6fadb0d5e046cba9b8e1af2c657b0c",
+ "13aa11ba42e94ffcbaeb833497d9c578",
+ "bc59bdc73a704ee888577f014ac64aac",
+ "36b234cb85d54c48807666c35e4c9c6b",
+ "5fc2ae69dd4148aca6a86b4acc85adcd",
+ "3c7e4c2215dd47fb8d44289d81d46380",
+ "b4e60cb8e1c74c7fa5ce959ac66f854b",
+ "7567dc9ff7e74703b9f87bc21b002345",
+ "64fb3d1e5f3c4447b54e1b147196625d",
+ "19c27c9a3905406b81ced45a093cc53c",
+ "d76767df8285449aa115018c2cdac272",
+ "550b6efe5bb541eb9cf3245f62befcb4",
+ "7a3eab87e135454d96232271c8b4b8e1",
+ "3e4262f1135e4a27832fb8c76e286593",
+ "026d6f14244b4643acc8f8ab87af1c71",
+ "1ad9b585d5934abc9a0930ef94799f63"
+ ]
+ },
+ "id": "JDJ92SEth4CY",
+ "outputId": "ffa58c93-1121-4a30-bf2a-b9f1369c5ac2"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "INFO:pytorch_lightning.accelerators.gpu:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:pytorch_lightning.callbacks.model_summary:\n",
+ " | Name | Type | Params\n",
+ "------------------------------------------------------\n",
+ "0 | model | MT5ForConditionalGeneration | 300 M \n",
+ "------------------------------------------------------\n",
+ "300 M Trainable params\n",
+ "0 Non-trainable params\n",
+ "300 M Total params\n",
+ "1,200.707 Total estimated model params size (MB)\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Validation sanity check: 0it [00:00, ?it/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "a61083da9c724a3ab6e70845e57e4bcf"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "INFO:pytorch_lightning.utilities.seed:Global seed set to 1\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "metric: exact_match, score: 0.0\n",
+ "Validation test\n",
+ "ground truth: 감염병의예방및관리에관한법률위반\n",
+ "prediction: 하였다.\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Training: 0it [00:00, ?it/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "ca38c82573d947c894dc5511b0d75874"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Validating: 0it [00:00, ?it/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "3c7e4c2215dd47fb8d44289d81d46380"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "metric: exact_match, score: 0.0\n",
+ "Validation test\n",
+ "ground truth: 감염병의예방및관리에관한법률위반\n",
+ "prediction: \n"
+ ]
+ },
+ {
+ "output_type": "error",
+ "ename": "MisconfigurationException",
+ "evalue": "ignored",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mMisconfigurationException\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata_module\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)\u001b[0m\n\u001b[1;32m 738\u001b[0m )\n\u001b[1;32m 739\u001b[0m \u001b[0mtrain_dataloaders\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_dataloader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 740\u001b[0;31m self._call_and_handle_interrupt(\n\u001b[0m\u001b[1;32m 741\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit_impl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_dataloaders\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_dataloaders\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatamodule\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 742\u001b[0m )\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(self, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 683\u001b[0m \"\"\"\n\u001b[1;32m 684\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 685\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtrainer_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 686\u001b[0m \u001b[0;31m# TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 687\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mexception\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 775\u001b[0m \u001b[0;31m# TODO: ckpt_path only in v1.7\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 776\u001b[0m \u001b[0mckpt_path\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mckpt_path\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 777\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mckpt_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 778\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 779\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstopped\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 1197\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1198\u001b[0m \u001b[0;31m# dispatch `start_training` or `start_evaluating` or `start_predicting`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1199\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1200\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1201\u001b[0m \u001b[0;31m# plugin will finalized fitting (e.g. ddp_spawn will load trained model)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_dispatch\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1277\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart_predicting\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1278\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1279\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart_training\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1280\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1281\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mrun_stage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py\u001b[0m in \u001b[0;36mstart_training\u001b[0;34m(self, trainer)\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstart_training\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"pl.Trainer\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0;31m# double dispatch to initiate the training loop\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 202\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_stage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 203\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstart_evaluating\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"pl.Trainer\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mrun_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1287\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredicting\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1288\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_predict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1289\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1290\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1291\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_pre_training_routine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_run_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1317\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1318\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_detect_anomaly\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_detect_anomaly\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1319\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1320\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1321\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_run_evaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0m_EVALUATE_OUTPUT\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/base.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 145\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madvance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 146\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrestarting\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py\u001b[0m in \u001b[0;36madvance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 232\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"run_training_epoch\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 234\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mepoch_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_fetcher\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 235\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[0;31m# the global step is manually decreased here due to backwards compatibility with existing loggers\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/base.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 151\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_run_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 152\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py\u001b[0m in \u001b[0;36mon_run_end\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 296\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 297\u001b[0m \u001b[0;31m# call train epoch end hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 298\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"on_train_epoch_end\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 299\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"on_epoch_end\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 300\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogger_connector\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_epoch_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mcall_hook\u001b[0;34m(self, hook_name, pl_module, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1493\u001b[0m \u001b[0mcallback_fx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1494\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcallable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcallback_fx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1495\u001b[0;31m \u001b[0mcallback_fx\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1496\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1497\u001b[0m \u001b[0;31m# next call hook in lightningModule\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/callback_hook.py\u001b[0m in \u001b[0;36mon_train_epoch_end\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[0;34m\"\"\"Called when the epoch ends.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mcallback\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m \u001b[0mcallback\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_train_epoch_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlightning_module\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mon_validation_epoch_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py\u001b[0m in \u001b[0;36mon_train_epoch_end\u001b[0;34m(self, trainer, pl_module)\u001b[0m\n\u001b[1;32m 319\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcurrent_epoch\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_every_n_epochs\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 320\u001b[0m ):\n\u001b[0;32m--> 321\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave_checkpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 322\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mglobal_step\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 323\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py\u001b[0m in \u001b[0;36msave_checkpoint\u001b[0;34m(self, trainer)\u001b[0m\n\u001b[1;32m 383\u001b[0m \u001b[0mglobal_step\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mglobal_step\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 384\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 385\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate_monitor_key\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 386\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 387\u001b[0m \u001b[0;31m# track epoch when ckpt was last checked\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py\u001b[0m in \u001b[0;36m_validate_monitor_key\u001b[0;34m(self, trainer)\u001b[0m\n\u001b[1;32m 644\u001b[0m \u001b[0mwarning_cache\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwarn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 645\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 646\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mMisconfigurationException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 647\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 648\u001b[0m def _get_metric_interpolated_filepath_name(\n",
+ "\u001b[0;31mMisconfigurationException\u001b[0m: ModelCheckpoint(monitor='val_loss') not found in the returned metrics: ['exact_match', 'training__ave_loss']. HINT: Did you call self.log('val_loss', value) in the LightningModule?"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class YourLightningModule(pl.LightningModule):\n",
+ " def training_step(self, batch, batch_idx):\n",
+ " # 훈련 로직을 여기에 추가\n",
+ " loss = ... # 훈련 손실 계산\n",
+ "\n",
+ " # 훈련 손실을 로그에 기록합니다\n",
+ " self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)\n",
+ " return loss\n",
+ "\n",
+ " def validation_step(self, batch, batch_idx):\n",
+ " # 검증 로직을 여기에 추가\n",
+ " loss = ... # 검증 손실 계산\n",
+ "\n",
+ " # 검증 손실을 로그에 기록합니다\n",
+ " self.log('val_loss', loss, on_epoch=True, prog_bar=True)\n",
+ " return loss"
+ ],
+ "metadata": {
+ "id": "-VkAg3TIyLYp"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "# 저장된 모델 체크포인트의 경로를 설정합니다.\n",
+ "model_checkpoint_path = \"./saved_model.ckpt\"\n",
+ "\n",
+ "# trainer.callback_metrics에서 훈련 및 검증 손실에 접근합니다.\n",
+ "train_losses = trainer.callback_metrics['train_loss']\n",
+ "val_losses = trainer.callback_metrics['val_loss']\n",
+ "\n",
+ "# 플로팅\n",
+ "plt.plot(train_losses, label='Train Loss')\n",
+ "plt.plot(val_losses, label='Validation Loss')\n",
+ "plt.xlabel('Epoch')\n",
+ "plt.ylabel('Loss')\n",
+ "plt.legend()\n",
+ "plt.show()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 256
+ },
+ "id": "1E7rdgSHp7z8",
+ "outputId": "f2f8130a-6f85-4039-9c9d-5701d443d970"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "error",
+ "ename": "KeyError",
+ "evalue": "ignored",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;31m# trainer.callback_metrics에서 훈련 및 검증 손실에 접근합니다.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;31m#train_losses = trainer.callback_metrics['train_loss']\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mval_losses\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallback_metrics\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'val_loss'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;31m# 플로팅\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mKeyError\u001b[0m: 'val_loss'"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# 7. Test\n",
+ "테스트는 input이 고정이여서 정확도 체크 용도로 쓰면 될듯합니다"
+ ],
+ "metadata": {
+ "id": "abpCIOfZw4H3"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "input_text = data_module.dataset[\"test\"][40][args.input_key]\n",
+ "\n",
+ "model_inputs = args.tokenizer(\n",
+ " input_text,\n",
+ " max_length=args.max_input_len,\n",
+ " padding=True,\n",
+ " truncation=True,\n",
+ " return_tensors='pt',\n",
+ " )\n",
+ "model_inputs = {k: v.to(device) for k,v in model_inputs.items()}"
+ ],
+ "metadata": {
+ "id": "7U5DbY8YyZn8"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "model.model = model.model.to(device)\n",
+ "pr_seqs = model.model.generate(model_inputs[\"input_ids\"], max_length=args.max_target_len)\n",
+ "prs = args.tokenizer.batch_decode(pr_seqs, skip_special_tokens=True)\n",
+ "print(f\"Input\\n {input_text}\\n\\n\")\n",
+ "print(f\"Prediction\\n {prs}\")"
+ ],
+ "metadata": {
+ "id": "SiwJ7HoqyeMm",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "cf14dd93-d6fc-4278-d99b-1916147eca7f"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Input\n",
+ " 검색하실 판례를 입력해 주세요:\n",
+ "\n",
+ "\n",
+ "Prediction\n",
+ " [')무추서']\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "#8. input을 임의로 넣을 수 있는 코드 추가"
+ ],
+ "metadata": {
+ "id": "rpnw3g2qqJHC"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# 사용자가 직접 입력한 텍스트를 사용\n",
+ "user_input_text = \"검색할 판례를 작성해주세요:\"\n",
+ "\n",
+ "# 모델 입력을 설정\n",
+ "model_inputs = args.tokenizer(\n",
+ "\n",
+ " user_input_text,\n",
+ " max_length=args.max_input_len,\n",
+ " padding=True,\n",
+ " truncation=True,\n",
+ " return_tensors='pt',\n",
+ ")\n",
+ "model_inputs = {k: v.to(device) for k, v in model_inputs.items()}\n",
+ "\n",
+ "# 모델에 대한 예측 생성\n",
+ "model.model = model.model.to(device)\n",
+ "pr_seqs = model.model.generate(model_inputs[\"input_ids\"], max_length=args.max_target_len)\n",
+ "\n",
+ "# 디코딩하여 출력\n",
+ "prs = args.tokenizer.batch_decode(pr_seqs, skip_special_tokens=True)\n",
+ "\n",
+ "# 입력 및 예측 출력\n",
+ "print(f\"Input\\n {user_input_text}\\n\\n\")\n",
+ "print(f\"Prediction\\n {prs}\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "HaVNMPYHoWC9",
+ "outputId": "23743326-4f4e-4aaf-8509-aabf237f9010"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Input\n",
+ " 검색할 판례를 작성해주세요:\n",
+ "\n",
+ "\n",
+ "Prediction\n",
+ " ['손해등이용nissen신청']\n"
+ ]
+ }
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/base_model2.ipynb b/base_model2.ipynb
new file mode 100644
index 0000000..4284707
--- /dev/null
+++ b/base_model2.ipynb
@@ -0,0 +1,825 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "#1. Install\n",
+ "warning은 무시해도 됩니다\n",
+ "임의로 버전을 고정한거라 뜨는 메세지 같아요"
+ ],
+ "metadata": {
+ "id": "Z6v4imMIm5gs"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install --upgrade setuptools #Python 프로젝트를 패키징하고 배포하며 설치하는 프로세스를 용이하게 하는 패키지입니다.\n",
+ "!pip install transformers==4.16.2 #자연어 처리를 위한 라이브러리로, 사전 훈련된 모델, 파인튜닝 도구 및 다양한 유틸리티를 제공합니다.\n",
+ "!pip install sentencepiece #신경망 기반 텍스트 생성 작업에 특화된 토큰화를 위한 라이브러리입니다.\n",
+ "!pip install datasets\n",
+ "!pip install rouge_score #ROUGE 스코어를 계산하기 위한 라이브러리로, 자연어 처리 작업에서 자주 사용됩니다.\n",
+ "!pip install pytorch_lightning==1.5.10 #고성능 신경망 훈련을 위한 PyTorch 래퍼입니다."
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ },
+ "id": "USqNXkEMjdE2",
+ "outputId": "28bce6d6-9c92-4fd4-c6a6-552cac4d79d9"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (59.5.0)\n",
+ "Collecting setuptools\n",
+ " Using cached setuptools-68.2.2-py3-none-any.whl (807 kB)\n",
+ "Installing collected packages: setuptools\n",
+ " Attempting uninstall: setuptools\n",
+ " Found existing installation: setuptools 59.5.0\n",
+ " Uninstalling setuptools-59.5.0:\n",
+ " Successfully uninstalled setuptools-59.5.0\n",
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "ipython 7.34.0 requires jedi>=0.16, which is not installed.\n",
+ "lida 0.0.10 requires fastapi, which is not installed.\n",
+ "lida 0.0.10 requires kaleido, which is not installed.\n",
+ "lida 0.0.10 requires python-multipart, which is not installed.\n",
+ "lida 0.0.10 requires uvicorn, which is not installed.\n",
+ "pytorch-lightning 1.5.10 requires setuptools==59.5.0, but you have setuptools 68.2.2 which is incompatible.\u001b[0m\u001b[31m\n",
+ "\u001b[0mSuccessfully installed setuptools-68.2.2\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "application/vnd.colab-display-data+json": {
+ "pip_warning": {
+ "packages": [
+ "_distutils_hack",
+ "pkg_resources",
+ "setuptools"
+ ]
+ }
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Requirement already satisfied: transformers==4.16.2 in /usr/local/lib/python3.10/dist-packages (4.16.2)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers==4.16.2) (3.13.1)\n",
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.1.0 in /usr/local/lib/python3.10/dist-packages (from transformers==4.16.2) (0.19.3)\n",
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.16.2) (1.23.5)\n",
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers==4.16.2) (23.2)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.16.2) (6.0.1)\n",
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.16.2) (2023.6.3)\n",
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers==4.16.2) (2.31.0)\n",
+ "Requirement already satisfied: sacremoses in /usr/local/lib/python3.10/dist-packages (from transformers==4.16.2) (0.1.1)\n",
+ "Requirement already satisfied: tokenizers!=0.11.3,>=0.10.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.16.2) (0.15.0)\n",
+ "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers==4.16.2) (4.66.1)\n",
+ "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers==4.16.2) (2023.6.0)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers==4.16.2) (4.5.0)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.16.2) (3.3.2)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.16.2) (3.4)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.16.2) (2.0.7)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.16.2) (2023.7.22)\n",
+ "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from sacremoses->transformers==4.16.2) (8.1.7)\n",
+ "Requirement already satisfied: joblib in /usr/local/lib/python3.10/dist-packages (from sacremoses->transformers==4.16.2) (1.3.2)\n",
+ "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (0.1.99)\n",
+ "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.15.0)\n",
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.23.5)\n",
+ "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (9.0.0)\n",
+ "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets) (0.5)\n",
+ "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.7)\n",
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n",
+ "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.31.0)\n",
+ "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.1)\n",
+ "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n",
+ "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.15)\n",
+ "Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n",
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.8.6)\n",
+ "Requirement already satisfied: huggingface-hub>=0.18.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.19.3)\n",
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (23.2)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n",
+ "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (3.3.2)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n",
+ "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.2)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.0)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.18.0->datasets) (3.13.1)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.18.0->datasets) (4.5.0)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.4)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2.0.7)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2023.7.22)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.3.post1)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n",
+ "Requirement already satisfied: rouge_score in /usr/local/lib/python3.10/dist-packages (0.1.2)\n",
+ "Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from rouge_score) (1.4.0)\n",
+ "Requirement already satisfied: nltk in /usr/local/lib/python3.10/dist-packages (from rouge_score) (3.8.1)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from rouge_score) (1.23.5)\n",
+ "Requirement already satisfied: six>=1.14.0 in /usr/local/lib/python3.10/dist-packages (from rouge_score) (1.16.0)\n",
+ "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (8.1.7)\n",
+ "Requirement already satisfied: joblib in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (1.3.2)\n",
+ "Requirement already satisfied: regex>=2021.8.3 in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (2023.6.3)\n",
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (4.66.1)\n",
+ "Requirement already satisfied: pytorch_lightning==1.5.10 in /usr/local/lib/python3.10/dist-packages (1.5.10)\n",
+ "Requirement already satisfied: numpy>=1.17.2 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (1.23.5)\n",
+ "Requirement already satisfied: torch>=1.7.* in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (2.1.0+cu118)\n",
+ "Requirement already satisfied: future>=0.17.1 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (0.18.3)\n",
+ "Requirement already satisfied: tqdm>=4.41.0 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (4.66.1)\n",
+ "Requirement already satisfied: PyYAML>=5.1 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (6.0.1)\n",
+ "Requirement already satisfied: fsspec[http]!=2021.06.0,>=2021.05.0 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (2023.6.0)\n",
+ "Requirement already satisfied: tensorboard>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (2.14.1)\n",
+ "Requirement already satisfied: torchmetrics>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (1.2.0)\n",
+ "Requirement already satisfied: pyDeprecate==0.3.1 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (0.3.1)\n",
+ "Requirement already satisfied: packaging>=17.0 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (23.2)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning==1.5.10) (4.5.0)\n",
+ "Collecting setuptools==59.5.0 (from pytorch_lightning==1.5.10)\n",
+ " Using cached setuptools-59.5.0-py3-none-any.whl (952 kB)\n",
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (2.31.0)\n",
+ "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (3.8.6)\n",
+ "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.2.0->pytorch_lightning==1.5.10) (1.4.0)\n",
+ "Requirement already satisfied: grpcio>=1.48.2 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.2.0->pytorch_lightning==1.5.10) (1.59.2)\n",
+ "Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.2.0->pytorch_lightning==1.5.10) (2.17.3)\n",
+ "Requirement already satisfied: google-auth-oauthlib<1.1,>=0.5 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.2.0->pytorch_lightning==1.5.10) (1.0.0)\n",
+ "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.2.0->pytorch_lightning==1.5.10) (3.5.1)\n",
+ "Requirement already satisfied: protobuf>=3.19.6 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.2.0->pytorch_lightning==1.5.10) (3.20.3)\n",
+ "Requirement already satisfied: six>1.9 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.2.0->pytorch_lightning==1.5.10) (1.16.0)\n",
+ "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.2.0->pytorch_lightning==1.5.10) (0.7.2)\n",
+ "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.2.0->pytorch_lightning==1.5.10) (3.0.1)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.*->pytorch_lightning==1.5.10) (3.13.1)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.*->pytorch_lightning==1.5.10) (1.12)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.*->pytorch_lightning==1.5.10) (3.2.1)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.*->pytorch_lightning==1.5.10) (3.1.2)\n",
+ "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.*->pytorch_lightning==1.5.10) (2.1.0)\n",
+ "Requirement already satisfied: lightning-utilities>=0.8.0 in /usr/local/lib/python3.10/dist-packages (from torchmetrics>=0.4.1->pytorch_lightning==1.5.10) (0.9.0)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (23.1.0)\n",
+ "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (3.3.2)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (6.0.4)\n",
+ "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (4.0.3)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (1.9.2)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (1.4.0)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (1.3.1)\n",
+ "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch_lightning==1.5.10) (5.3.2)\n",
+ "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch_lightning==1.5.10) (0.3.0)\n",
+ "Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch_lightning==1.5.10) (4.9)\n",
+ "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from google-auth-oauthlib<1.1,>=0.5->tensorboard>=2.2.0->pytorch_lightning==1.5.10) (1.3.1)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (3.4)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (2.0.7)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning==1.5.10) (2023.7.22)\n",
+ "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=1.0.1->tensorboard>=2.2.0->pytorch_lightning==1.5.10) (2.1.3)\n",
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.7.*->pytorch_lightning==1.5.10) (1.3.0)\n",
+ "Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /usr/local/lib/python3.10/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch_lightning==1.5.10) (0.5.0)\n",
+ "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.10/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<1.1,>=0.5->tensorboard>=2.2.0->pytorch_lightning==1.5.10) (3.2.2)\n",
+ "Installing collected packages: setuptools\n",
+ " Attempting uninstall: setuptools\n",
+ " Found existing installation: setuptools 68.2.2\n",
+ " Uninstalling setuptools-68.2.2:\n",
+ " Successfully uninstalled setuptools-68.2.2\n",
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "ipython 7.34.0 requires jedi>=0.16, which is not installed.\n",
+ "lida 0.0.10 requires fastapi, which is not installed.\n",
+ "lida 0.0.10 requires kaleido, which is not installed.\n",
+ "lida 0.0.10 requires python-multipart, which is not installed.\n",
+ "lida 0.0.10 requires uvicorn, which is not installed.\n",
+ "arviz 0.15.1 requires setuptools>=60.0.0, but you have setuptools 59.5.0 which is incompatible.\n",
+ "cvxpy 1.3.2 requires setuptools>65.5.1, but you have setuptools 59.5.0 which is incompatible.\u001b[0m\u001b[31m\n",
+ "\u001b[0mSuccessfully installed setuptools-59.5.0\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "application/vnd.colab-display-data+json": {
+ "pip_warning": {
+ "packages": [
+ "_distutils_hack",
+ "pkg_resources",
+ "setuptools"
+ ]
+ }
+ }
+ },
+ "metadata": {}
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!git clone https://github.com/lbox-kr/lbox_open.git --branch v0.1\n",
+ "%cd lbox_open"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "q3YgO47Hxz9f",
+ "outputId": "02241cc5-5518-45e8-ecfd-13dcc61affed"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Cloning into 'lbox_open'...\n",
+ "remote: Enumerating objects: 266, done.\u001b[K\n",
+ "remote: Counting objects: 100% (266/266), done.\u001b[K\n",
+ "remote: Compressing objects: 100% (161/161), done.\u001b[K\n",
+ "remote: Total 266 (delta 128), reused 210 (delta 94), pack-reused 0\u001b[K\n",
+ "Receiving objects: 100% (266/266), 86.64 KiB | 677.00 KiB/s, done.\n",
+ "Resolving deltas: 100% (128/128), done.\n",
+ "Note: switching to '5dd718fe2424d84bc73a78890b357b429c2f9cb0'.\n",
+ "\n",
+ "You are in 'detached HEAD' state. You can look around, make experimental\n",
+ "changes and commit them, and you can discard any commits you make in this\n",
+ "state without impacting any branches by switching back to a branch.\n",
+ "\n",
+ "If you want to create a new branch to retain commits you create, you may\n",
+ "do so (now or later) by using -c with the switch command. Example:\n",
+ "\n",
+ " git switch -c \n",
+ "\n",
+ "Or undo this operation with:\n",
+ "\n",
+ " git switch -\n",
+ "\n",
+ "Turn off this advice by setting config variable advice.detachedHead to false\n",
+ "\n",
+ "/content/lbox_open\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import os\n",
+ "from argparse import Namespace\n",
+ "\n",
+ "import torch\n",
+ "import transformers\n",
+ "import pytorch_lightning as pl\n",
+ "\n",
+ "from lcube.data_module.data_lbox_open import LBoxOpenDataModule\n",
+ "from lcube.model.model_baseline import SeqToSeqBaseline\n",
+ "\n",
+ "device = 'cuda' if torch.cuda.is_available() else \"cpu\""
+ ],
+ "metadata": {
+ "id": "JM_q343LkFbq"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "#2.Parameter 생성"
+ ],
+ "metadata": {
+ "id": "gKEO97qomg3U"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "args = Namespace()\n",
+ "# dataset\n",
+ "args.dataset_card = \"lbox/lbox_open\" #데이터셋 카드를 지정하는데, 현재 \"lbox/lbox_open\"으로 설정되어 있는 것으로 보입니다.\n",
+ "args.task = \"casename_classification\" # 작업 유형을 지정하는데, 케이스 이름 분류, 법률 분류 또는 요약의 옵션\n",
+ "# args.task = \"statute_classification\"\n",
+ "# args.task = \"summarization\"\n",
+ "\n",
+ "\n",
+ "if args.task in [\"casename_classification\", \"statute_classification\"]:\n",
+ " args.input_key = \"facts\"\n",
+ "\n",
+ " # model\n",
+ " args.model_card = \"google/mt5-small\" #모델 카드를 \"google/mt5-small\"으로 설정\n",
+ " args.max_input_len = 512\n",
+ " args.max_target_len = 64\n",
+ "\n",
+ " # train\n",
+ " args.max_epochs = 10\n",
+ " args.learning_rate = 2e-4\n",
+ " args.batch_size = 8\n",
+ " args.batch_size_eval = 2 * args.batch_size\n",
+ " args.accumulate_grad_batches = 1\n",
+ " args.validation_metric = \"exact_match\"\n",
+ "\n",
+ "elif args.task == \"summarization\":\n",
+ " args.input_key = \"precedent\"\n",
+ "\n",
+ " # model\n",
+ " args.model_card = \"google/mt5-small\"\n",
+ " args.max_input_len = 768\n",
+ " args.max_target_len = 512\n",
+ "\n",
+ " # train\n",
+ " args.max_epochs = 10\n",
+ " args.learning_rate = 2e-4\n",
+ " args.batch_size = 1\n",
+ " args.batch_size_eval = 2 * args.batch_size\n",
+ " args.accumulate_grad_batches = 8\n",
+ " args.validation_metric = \"rougeL\"\n",
+ "\n",
+ "else:\n",
+ " raise ValueError\n",
+ "\n",
+ "\n",
+ "args.tokenizer = transformers.MT5TokenizerFast.from_pretrained(args.model_card)\n",
+ "pl.seed_everything(seed=1, workers=False)"
+ ],
+ "metadata": {
+ "id": "CV56eKVyn8fr",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "9da131d7-148d-4c24-a0c9-21127c0b2842"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
+ "INFO:pytorch_lightning.utilities.seed:Global seed set to 1\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "1"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 27
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# 3. Datasets 불러오기"
+ ],
+ "metadata": {
+ "id": "_iC4_4__oPca"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "data_module = LBoxOpenDataModule(\n",
+ " args.dataset_card,\n",
+ " args.task,\n",
+ " args.tokenizer,\n",
+ " args.max_input_len,\n",
+ " args.max_target_len,\n",
+ " args.batch_size,\n",
+ " args.batch_size_eval,\n",
+ ")"
+ ],
+ "metadata": {
+ "id": "uyerxm0woUbl"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# 4. 모델 준비\n",
+ "MT5ForConditionalGeneration"
+ ],
+ "metadata": {
+ "id": "8DpPA9wPogHW"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "backbone = transformers.MT5ForConditionalGeneration.from_pretrained(args.model_card)\n",
+ "model = SeqToSeqBaseline(\n",
+ " args.task,\n",
+ " backbone,\n",
+ " args.tokenizer,\n",
+ " args.learning_rate,\n",
+ " args.max_target_len,\n",
+ " args.validation_metric\n",
+ "\n",
+ ")"
+ ],
+ "metadata": {
+ "id": "P_Pz746Jol4S"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# 5. 학습 준비\n",
+ "epoch 5~10 정도 해야할듯 합니다\n",
+ "\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "_2q1EoVjo1qr"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## 성능 검증 그래프화"
+ ],
+ "metadata": {
+ "id": "O549_bnI3wa9"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "from pytorch_lightning import LightningModule, Trainer, callbacks\n",
+ "\n",
+ "class PlotLossesCallback(callbacks.Callback):\n",
+ " def __init__(self):\n",
+ " self.train_losses = []\n",
+ " self.val_losses = []\n",
+ "\n",
+ " def on_epoch_end(self, trainer, pl_module):\n",
+ " # 에폭 종료 시 호출되며 훈련 및 검증 손실을 기록합니다.\n",
+ " self.train_losses.append(trainer.callback_metrics.get(\"train_loss\"))\n",
+ " self.val_losses.append(trainer.callback_metrics.get(\"val_loss\"))\n",
+ "\n",
+ " # 손실 그래프를 업데이트합니다.\n",
+ " self.plot_losses()\n",
+ "\n",
+ " def plot_losses(self):\n",
+ " plt.figure(figsize=(10, 5))\n",
+ " plt.plot(self.train_losses, label='Training Loss', color='blue')\n",
+ " plt.plot(self.val_losses, label='Validation Loss', color='red')\n",
+ " plt.xlabel('Epochs')\n",
+ " plt.ylabel('Loss')\n",
+ " plt.title('Training and Validation Loss')\n",
+ " plt.legend()\n",
+ " plt.show()\n",
+ "\n",
+ "class YourLightningModule(LightningModule):\n",
+ " def validation_step(self, batch, batch_idx):\n",
+ " # 여기에 검증 단계 논리를 작성하세요\n",
+ " outputs = self.forward(batch)\n",
+ "\n",
+ " # 예측과 정답 간의 손실 계산\n",
+ " loss = your_loss_function(outputs, batch['target']) # 적절한 손실 함수를 사용해야 합니다.\n",
+ "\n",
+ " # 검증 손실을 로그에 기록\n",
+ " self.log('val_loss', loss, on_epoch=True, prog_bar=True)\n",
+ "\n",
+ " return outputs\n",
+ "\n",
+ "# ModelCheckpoint 콜백 및 그래프화 콜백을 설정합니다.\n",
+ "model_checkpoint = callbacks.ModelCheckpoint(monitor='val_loss', dirpath='./saved/0/{args.task}', save_top_k=1, mode='max')\n",
+ "plot_losses_callback = PlotLossesCallback()"
+ ],
+ "metadata": {
+ "id": "CsqnzSaV3vbe"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "callbacks = pl.callbacks.ModelCheckpoint(\n",
+ " monitor=args.validation_metric,\n",
+ " dirpath=f\"./saved/0/{args.task}\",\n",
+ " save_top_k=1,\n",
+ " mode=\"max\",\n",
+ ")\n",
+ "trainer = pl.Trainer(\n",
+ " max_epochs = args.max_epochs,\n",
+ " gpus=torch.cuda.device_count(),\n",
+ " accumulate_grad_batches=args.accumulate_grad_batches,\n",
+ " fast_dev_run=not True,\n",
+ " callbacks=callbacks,\n",
+ ")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "SCsu57OO9riy",
+ "outputId": "87d43b7b-173b-4eaf-9db8-96eb61016578"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "INFO:pytorch_lightning.utilities.distributed:GPU available: True, used: True\n",
+ "INFO:pytorch_lightning.utilities.distributed:TPU available: False, using: 0 TPU cores\n",
+ "INFO:pytorch_lightning.utilities.distributed:IPU available: False, using: 0 IPUs\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#trainer = pl.Trainer(\n",
+ " # max_epochs=1,\n",
+ " # gpus=torch.cuda.device_count(),\n",
+ " # accumulate_grad_batches=args.accumulate_grad_batches,\n",
+ " # fast_dev_run=not True,\n",
+ " # callbacks=[model_checkpoint, plot_losses_callback]\n",
+ ")\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "czojZOEa0G-R",
+ "outputId": "a8cd74ed-413e-4438-89b3-ac4e97240dd3"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "INFO:pytorch_lightning.utilities.distributed:GPU available: True, used: True\n",
+ "INFO:pytorch_lightning.utilities.distributed:TPU available: False, using: 0 TPU cores\n",
+ "INFO:pytorch_lightning.utilities.distributed:IPU available: False, using: 0 IPUs\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# PyTorch Lightning을 사용하여 모델을 훈련하고 지정된\n",
+ "# 검증 메트릭에 대한 최고 성능을 갖는 모델을 저장하기 위해\n",
+ "# ModelCheckpoint 콜백을 설정합니다.\n",
+ "#callbacks = pl.callbacks.ModelCheckpoint(\n",
+ " #monitor=args.validation_metric, # 검증 메트릭을 모니터링하여 최고 성능 모델을 찾습니다.\n",
+ " # monitor='val_loss',\n",
+ " # dirpath=f\"./saved/0/{args.task}\", # 최고 성능 모델이 저장될 디렉토리입니다.\n",
+ " # save_top_k=1, # 최고의 모델 하나만 저장합니다.\n",
+ " # mode=\"max\", # 모니터링 메트릭의 최대 값을 가진 모델을 저장합니다.\n",
+ "#)\n",
+ "\n",
+ "# PyTorch Lightning Trainer 인스턴스를 생성합니다.\n",
+ "#trainer = pl.Trainer(\n",
+ " # max_epochs= 2, # 최대 에폭 수는 1입니다. 더 많은 에폭으로 훈련하려면 10으로 주석 처리된 줄을 사용할 수 있습니다.\n",
+ " #max_epochs = 10,\n",
+ " # gpus=torch.cuda.device_count(), # 사용할 GPU 수를 설정합니다.\n",
+ " # accumulate_grad_batches=args.accumulate_grad_batches, # 그래디언트 누적 배치 수를 설정합니다.\n",
+ " #fast_dev_run=not True, # True로 설정하면 훈련 및 검증을 위해 하나의 배치만 실행됩니다.\n",
+ " #callbacks=callbacks, # 사용할 콜백을 설정합니다.\n",
+ "\n",
+ "#)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "NRqnxtjOwwRH",
+ "outputId": "f6335b9c-2c8b-468a-9b18-409c618d9dc5"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "INFO:pytorch_lightning.utilities.distributed:GPU available: True, used: True\n",
+ "INFO:pytorch_lightning.utilities.distributed:TPU available: False, using: 0 TPU cores\n",
+ "INFO:pytorch_lightning.utilities.distributed:IPU available: False, using: 0 IPUs\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# 6. 학습\n",
+ "평균 1회당 9분 정도 소요"
+ ],
+ "metadata": {
+ "id": "oTAV9ONrh1P3"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "trainer.fit(model, data_module)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 522
+ },
+ "id": "JDJ92SEth4CY",
+ "outputId": "263dcab9-77cf-4fb8-87c2-4425dc633d07"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/core/datamodule.py:469: LightningDeprecationWarning: DataModule.setup has already been called, so it will not be called again. In v1.6 this behavior will change to always call DataModule.setup.\n",
+ " rank_zero_deprecation(\n",
+ "INFO:pytorch_lightning.accelerators.gpu:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
+ ]
+ },
+ {
+ "output_type": "error",
+ "ename": "OutOfMemoryError",
+ "evalue": "ignored",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mOutOfMemoryError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata_module\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)\u001b[0m\n\u001b[1;32m 738\u001b[0m )\n\u001b[1;32m 739\u001b[0m \u001b[0mtrain_dataloaders\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_dataloader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 740\u001b[0;31m self._call_and_handle_interrupt(\n\u001b[0m\u001b[1;32m 741\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit_impl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_dataloaders\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_dataloaders\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatamodule\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 742\u001b[0m )\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(self, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 683\u001b[0m \"\"\"\n\u001b[1;32m 684\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 685\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtrainer_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 686\u001b[0m \u001b[0;31m# TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 687\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mexception\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 775\u001b[0m \u001b[0;31m# TODO: ckpt_path only in v1.7\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 776\u001b[0m \u001b[0mckpt_path\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mckpt_path\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 777\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mckpt_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 778\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 779\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstopped\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 1143\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1144\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_configure_sharded_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# allow user to setup in model sharded environment\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1145\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1146\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1147\u001b[0m \u001b[0;31m# ----------------------------\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/accelerators/gpu.py\u001b[0m in \u001b[0;36msetup\u001b[0;34m(self, trainer)\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msetup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"pl.Trainer\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_nvidia_flags\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlocal_rank\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 46\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 47\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mon_train_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/accelerators/accelerator.py\u001b[0m in \u001b[0;36msetup\u001b[0;34m(self, trainer)\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mtrainer\u001b[0m \u001b[0minstance\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m \"\"\"\n\u001b[0;32m---> 91\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetup_training_type_plugin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 92\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetup_optimizers_in_pre_dispatch\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetup_optimizers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/accelerators/accelerator.py\u001b[0m in \u001b[0;36msetup_training_type_plugin\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msetup_training_type_plugin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 362\u001b[0m \u001b[0;34m\"\"\"Attaches the training type plugin to the accelerator.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 363\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 364\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 365\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msetup_precision_plugin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/plugins/training_type/single_device.py\u001b[0m in \u001b[0;36msetup\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msetup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 71\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_to_device\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 72\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/plugins/training_type/single_device.py\u001b[0m in \u001b[0;36mmodel_to_device\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmodel_to_device\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 68\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroot_device\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 69\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msetup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/pytorch_lightning/core/mixins/device_dtype_mixin.py\u001b[0m in \u001b[0;36mto\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_parse_to\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 110\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__update_properties\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 111\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 112\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;34m\"DeviceDtypeModuleMixin\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mto\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1158\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_complex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1159\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1160\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1161\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1162\u001b[0m def register_full_backward_pre_hook(\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 808\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrecurse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 809\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 810\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 811\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 812\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 808\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrecurse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 809\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 810\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 811\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 812\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 831\u001b[0m \u001b[0;31m# `with torch.no_grad():`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 832\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 833\u001b[0;31m \u001b[0mparam_applied\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 834\u001b[0m \u001b[0mshould_use_set_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 835\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mshould_use_set_data\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 1156\u001b[0m return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,\n\u001b[1;32m 1157\u001b[0m non_blocking, memory_format=convert_to_format)\n\u001b[0;32m-> 1158\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_complex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1159\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1160\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 490.00 MiB. GPU 0 has a total capacty of 14.75 GiB of which 42.81 MiB is free. Process 6215 has 14.70 GiB memory in use. Of the allocated memory 13.43 GiB is allocated by PyTorch, and 262.09 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# 7. Test\n",
+ "테스트는 input이 고정이여서 정확도 체크 용도로 쓰면 될듯합니다"
+ ],
+ "metadata": {
+ "id": "abpCIOfZw4H3"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "input_text = data_module.dataset[\"test\"][40][args.input_key]\n",
+ "\n",
+ "model_inputs = args.tokenizer(\n",
+ " input_text,\n",
+ " max_length=args.max_input_len,\n",
+ " padding=True,\n",
+ " truncation=True,\n",
+ " return_tensors='pt',\n",
+ " )\n",
+ "model_inputs = {k: v.to(device) for k,v in model_inputs.items()}"
+ ],
+ "metadata": {
+ "id": "7U5DbY8YyZn8"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "model.model = model.model.to(device)\n",
+ "pr_seqs = model.model.generate(model_inputs[\"input_ids\"], max_length=args.max_target_len)\n",
+ "prs = args.tokenizer.batch_decode(pr_seqs, skip_special_tokens=True)\n",
+ "print(f\"Input\\n {input_text}\\n\\n\")\n",
+ "print(f\"Prediction\\n {prs}\")"
+ ],
+ "metadata": {
+ "id": "SiwJ7HoqyeMm",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "cf14dd93-d6fc-4278-d99b-1916147eca7f"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Input\n",
+ " 검색하실 판례를 입력해 주세요:\n",
+ "\n",
+ "\n",
+ "Prediction\n",
+ " [')무추서']\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "#8. input을 임의로 넣을 수 있는 코드 추가"
+ ],
+ "metadata": {
+ "id": "rpnw3g2qqJHC"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from transformers import MT5TokenizerFast, MT5ForConditionalGeneration\n",
+ "\n",
+ "# 토크나이저 및 모델 초기화\n",
+ "tokenizer = MT5TokenizerFast.from_pretrained(args.model_card)\n",
+ "model = MT5ForConditionalGeneration.from_pretrained(args.model_card)\n",
+ "\n",
+ "# 저장된 모델 불러오기\n",
+ "checkpoint_path = \"./saved/0/{args.task}/epoch=2.ckpt\" # 여기에 실제로 저장된 모델의 경로를 지정하세요.\n",
+ "model.load_state_dict(torch.load(checkpoint_path)[\"state_dict\"])\n",
+ "\n",
+ "# 검색 기능 함수 정의\n",
+ "def search(query):\n",
+ " inputs = tokenizer.encode(query, return_tensors=\"pt\", max_length=args.max_input_len, truncation=True)\n",
+ " outputs = model.generate(inputs, max_length=args.max_target_len, num_beams=4, length_penalty=2.0, early_stopping=True)\n",
+ " result = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
+ " return result\n",
+ "\n",
+ "# 검색어 입력 및 검색 수행\n",
+ "search_query = \"검색하실 판례를 작성해주세요:\"\n",
+ "search_result = search(search_query)\n",
+ "print(\"위법 법률:\", search_result)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "HaVNMPYHoWC9",
+ "outputId": "23743326-4f4e-4aaf-8509-aabf237f9010"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Input\n",
+ " 검색할 판례를 작성해주세요:\n",
+ "\n",
+ "\n",
+ "Prediction\n",
+ " ['손해등이용nissen신청']\n"
+ ]
+ }
+ ]
+ }
+ ]
+}
\ No newline at end of file
| | |