diff --git a/notebooks/GenAI/AWS_Bedrock_Intro.ipynb b/notebooks/GenAI/AWS_Bedrock_Intro.ipynb index 520a5f9..58979a2 100644 --- a/notebooks/GenAI/AWS_Bedrock_Intro.ipynb +++ b/notebooks/GenAI/AWS_Bedrock_Intro.ipynb @@ -71,7 +71,7 @@ "tags": [] }, "source": [ - "![bedrock_overview](../../../docs/images/bedrock_page.png)" + "![bedrock_overview](../../docs/images/bedrock_page.png)" ] }, { @@ -85,7 +85,9 @@ "- Titan Embeddings G1 - Text\n", "- Claude (need to submit a use case)\n", "\n", - "You are only charged when you use the model unlike jumpstart where you turn on an endpoint and are charged as long as its running, more detail can be found [here](https://aws.amazon.com/bedrock/pricing/)." + "You are only charged when you use the model unlike jumpstart where you turn on an endpoint and are charged as long as it's running, more detail can be found [here](https://aws.amazon.com/bedrock/pricing/).\n", + "\n", + "**Warning:** Do not forget to opt out of sharing your data before requesting access to these models!" 
] }, { @@ -95,7 +97,7 @@ "tags": [] }, "source": [ - "![bedrock_models](../../../docs/images/bedrock_model_access.png)" + "![bedrock_models](../../docs/images/bedrock_model_access.png)" ] }, { @@ -127,7 +129,7 @@ "id": "8f8c3521-ca5a-4c74-bc44-6ee18261db97", "metadata": {}, "source": [ - "![bedrock_chat_playground_1](../../../docs/images/bedrock_chat_playground_1.png)" + "![bedrock_chat_playground_1](../../docs/images/bedrock_chat_playground_1.png)" ] }, { @@ -143,7 +145,7 @@ "id": "51e613aa-de17-4344-91fe-e65497364a2d", "metadata": {}, "source": [ - "![bedrock_chat_playground_2](../../../docs/images/bedrock_chat_playground_2.png)" + "![bedrock_chat_playground_2](../../docs/images/bedrock_chat_playground_2.png)" ] }, { @@ -162,7 +164,7 @@ "id": "1210bcc7-7084-414e-bc67-ec5a712c54fa", "metadata": {}, "source": [ - "![bedrock_chat_playground_4](../../../docs/images/bedrock_chat_playground_4.png)" + "![bedrock_chat_playground_4](../../docs/images/bedrock_chat_playground_4.png)" ] }, { @@ -201,7 +203,7 @@ "tags": [] }, "source": [ - "![bedrock_chat_playground_3](../../../docs/images/bedrock_chat_playground_3.png)" + "![bedrock_chat_playground_3](../../docs/images/bedrock_chat_playground_3.png)" ] }, { @@ -227,7 +229,7 @@ "source": [ "### Load in data\n", "\n", - "Below we are creating a bucket to store our articles, then we will copy the metadata from the PubMed bucket and parse that to only list the path of the first 100 articles within that bucket. The last step will be to copy those articles to our bucket." + "Below we are creating a bucket to store our articles, then we will copy the metadata from the PubMed bucket and parse that to only list the path of the first 50 articles within that bucket. The last step will be to copy those articles to our bucket." 
] }, { @@ -239,12 +241,12 @@ "source": [ "#make bucket, dont forget to add your own bucket name\n", "bucket = 'pubmed-chat-docs'\n", - "!aws s3 mb s3://{bucket}\n" + "!aws s3 mb s3://{bucket}" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 1, "id": "cbfffd09-7983-419b-a081-7f1cd7e98736", "metadata": { "tags": [] }, @@ -274,10 +276,10 @@ "import pandas as pd\n", "import os\n", "df = pd.read_csv('oa_comm.filelist.csv')\n", - "#first 100 files\n", - "first_100=df[0:100]\n", + "#first 50 files\n", + "first_50=df[0:50]\n", "#save new metadata\n", - "first_100.to_csv('oa_comm.filelist_100.csv', index=False)" + "first_50.to_csv('oa_comm.filelist_50.csv', index=False)" ] }, { @@ -289,7 +291,7 @@ "source": [ "import os\n", "#gather path to files in bucket\n", - "for i in first_100['Key']:\n", + "for i in first_50['Key']:\n", " os.system(f'aws s3 cp s3://pmc-oa-opendata/{i} s3://{bucket}/docs/ --sse')" ] }, @@ -316,7 +318,7 @@ "tags": [] }, "source": [ - "![bedrock_knowledgebase_1](../../../docs/images/bedrock_knowledgebase_1.png)" + "![bedrock_knowledgebase_1](../../docs/images/bedrock_knowledgebase_1.png)" ] }, { @@ -332,7 +334,7 @@ "id": "c4a53f7b-da0f-4ef1-bc78-fed8f411fb7c", "metadata": {}, "source": [ - "![bedrock_knowledgebase_2](../../../docs/images/bedrock_knowledgebase_2.png)" + "![bedrock_knowledgebase_2](../../docs/images/bedrock_knowledgebase_2.png)" ] }, { @@ -348,7 +350,7 @@ "id": "b815514d-d9a6-4a2b-b46e-b8b928451138", "metadata": {}, "source": [ - "![bedrock_knowledgebase_3](../../../docs/images/bedrock_knowledgebase_3.png)" + "![bedrock_knowledgebase_3](../../docs/images/bedrock_knowledgebase_3.png)" ] }, { @@ -366,7 +368,7 @@ "id": "4a2106fd-e627-4620-8438-9cb4fcf04105", "metadata": {}, "source": [ - "![bedrock_knowledgebase_4](../../../docs/images/bedrock_knowledgebase_4.png)" + "![bedrock_knowledgebase_4](../../docs/images/bedrock_knowledgebase_4.png)" ] }, { @@ -382,7 +384,7 @@ "id": "68c5e7e8-1247-4f09-8376-8ac4d156873e", 
"metadata": {}, "source": [ - "![bedrock_knowledgebase_5](../../../docs/images/bedrock_knowledgebase_5.png)" + "![bedrock_knowledgebase_5](../../docs/images/bedrock_knowledgebase_5.png)" ] }, { @@ -400,7 +402,7 @@ "tags": [] }, "source": [ - "![bedrock_knowledgebase_6](../../../docs/images/bedrock_knowledgebase_6.png)" + "![bedrock_knowledgebase_6](../../docs/images/bedrock_knowledgebase_6.png)" ] }, { @@ -418,7 +420,7 @@ "id": "cda72990-bf03-49c8-8520-44f83ff83652", "metadata": {}, "source": [ - "![bedrock_knowledgebase_7](../../../docs/images/bedrock_knowledgebase_7.png)" + "![bedrock_knowledgebase_7](../../docs/images/bedrock_knowledgebase_7.png)" ] }, { @@ -434,7 +436,7 @@ "id": "2d672fa0-4ecc-48df-afea-1e657e04bc2c", "metadata": {}, "source": [ - "![bedrock_knowledgebase_8](../../../docs/images/bedrock_knowledgebase_8.png)" + "![bedrock_knowledgebase_8](../../docs/images/bedrock_knowledgebase_8.png)" ] }, { @@ -450,7 +452,7 @@ "id": "5e398278-0ce3-4b26-b48e-fea1d947058d", "metadata": {}, "source": [ - "![bedrock_knowledgebase_9](../../../docs/images/bedrock_knowledgebase_9.png)" + "![bedrock_knowledgebase_9](../../docs/images/bedrock_knowledgebase_9.png)" ] }, { @@ -466,7 +468,7 @@ "id": "2c29c470-0870-4ff8-b1c3-3f0a6c812e24", "metadata": {}, "source": [ - "![bedrock_knowledgebase_10](../../../docs/images/bedrock_knowledgebase_10.png)" + "![bedrock_knowledgebase_10](../../docs/images/bedrock_knowledgebase_10.png)" ] }, { @@ -490,7 +492,7 @@ "id": "5e0e8aa3-84db-40a8-9b62-579b3d01a8e5", "metadata": {}, "source": [ - "![bedrock_agents_1](../../../docs/images/bedrock_agents_1.png)" + "![bedrock_agents_1](../../docs/images/bedrock_agents_1.png)" ] }, { @@ -506,7 +508,7 @@ "id": "01633367-7cbe-4205-954c-3e7577fa2fe4", "metadata": {}, "source": [ - "![bedrock_agents_2](../../../docs/images/bedrock_agents_2.png)" + "![bedrock_agents_2](../../docs/images/bedrock_agents_2.png)" ] }, { @@ -522,7 +524,7 @@ "id": "1d1c6314-09ab-437e-b4ea-7d409d136094", "metadata": {}, 
"source": [ - "![bedrock_agents_3](../../../docs/images/bedrock_agents_3.png)" + "![bedrock_agents_3](../../docs/images/bedrock_agents_3.png)" ] }, { @@ -538,7 +540,7 @@ "id": "af4b2535-cca3-4068-9b80-3fbdd17a5b6b", "metadata": {}, "source": [ - "![bedrock_agents_4](../../../docs/images/bedrock_agents_4.png)" + "![bedrock_agents_4](../../docs/images/bedrock_agents_4.png)" ] }, { @@ -554,7 +556,7 @@ "id": "3a178345-22b5-446c-8e04-ad9e50584b54", "metadata": {}, "source": [ - "![bedrock_agents_5](../../../docs/images/bedrock_agents_5.png)" + "![bedrock_agents_5](../../docs/images/bedrock_agents_5.png)" ] }, { @@ -570,7 +572,7 @@ "id": "4f7840b2-0fcc-4bfe-8f0a-f253b64f9f7d", "metadata": {}, "source": [ - "![bedrock_agents_6](../../../docs/images/bedrock_agents_6.png)" + "![bedrock_agents_6](../../docs/images/bedrock_agents_6.png)" ] }, { @@ -586,7 +588,7 @@ "id": "6a184afa-7f06-4439-bd7e-5a8e7d5c9395", "metadata": {}, "source": [ - "![bedrock_agents_7](../../../docs/images/bedrock_agents_7.png)" + "![bedrock_agents_7](../../docs/images/bedrock_agents_7.png)" ] }, { diff --git a/notebooks/GenAI/AWS_GenAI_Huggingface.ipynb b/notebooks/GenAI/AWS_GenAI_Huggingface.ipynb index 202e9ee..0b01109 100644 --- a/notebooks/GenAI/AWS_GenAI_Huggingface.ipynb +++ b/notebooks/GenAI/AWS_GenAI_Huggingface.ipynb @@ -165,6 +165,14 @@ "train_dataset, test_dataset = load_dataset(\"ccdv/pubmed-summarization\", split=[\"train\", \"test\"])\n" ] }, + { + "cell_type": "markdown", + "id": "3399abb1-af8f-46ee-92ea-c8344eeddd09", + "metadata": {}, + "source": [ + "## Finetuning our Model Locally" + ] + }, { "cell_type": "markdown", "id": "ed6ddff1-2636-4e3b-88ee-e3c86c584245", @@ -187,9 +195,10 @@ "outputs": [], "source": [ "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n", + "model_name=\"google/flan-t5-small\"\n", "\n", - "model = AutoModelForSeq2SeqLM.from_pretrained(\"google/flan-t5-small\")\n", - "tokenizer = AutoTokenizer.from_pretrained(\"google/flan-t5-small\")" + "model = 
AutoModelForSeq2SeqLM.from_pretrained(model_name)\n", + "tokenizer = AutoTokenizer.from_pretrained(model_name)" ] }, { @@ -230,6 +239,106 @@ "test_dataset.set_format(\"torch\", columns=[\"input_ids\", \"attention_mask\", \"abstracts\"])" ] }, + { + "cell_type": "markdown", + "id": "b3ffd612-abde-4666-8c85-cc7069de2129", + "metadata": {}, + "source": [ + "The first step to training our model other than setting up our datasets is to set our **hyperparameters**. Hyperparameters depend on your training script and for this one we need to identify our model, the location of our train and test files, etc. In this case we are using one created by Hugging Face." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c06bef19-cc3c-476f-943c-78368e9f49e8", + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import TrainingArguments\n", + "\n", + "training_args = TrainingArguments(output_dir=\"test_trainer\")" + ] + }, + { + "cell_type": "markdown", + "id": "cff31d69-9f54-4235-a377-7c5e758fbca8", + "metadata": {}, + "source": [ + "Next, create settings to evaluate the model's accuracy." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24bbe62e-9140-4bef-88ae-3e5029ddb25c", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import evaluate\n", + "\n", + "metric = evaluate.load(\"accuracy\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b82caeba-2daa-4526-b67d-04f45d4a9934", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_metrics(eval_pred):\n", + " logits, labels = eval_pred\n", + " predictions = np.argmax(logits, axis=-1)\n", + " return metric.compute(predictions=predictions, references=labels)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5b50ec0-87b8-4578-96aa-e26bda9d99b8", + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import TrainingArguments, Trainer\n", + "\n", + "training_args = TrainingArguments(output_dir=\"test_trainer\", evaluation_strategy=\"epoch\")" + ] + }, + { + "cell_type": "markdown", + "id": "df2225ac-8e92-4a14-a368-eebff9ead6bf", + "metadata": {}, + "source": [ + "Finally we can train our model!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e59332ae-c9e3-4a9b-9a7c-7020c87227da", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = Trainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=train_dataset,\n", + " eval_dataset=test_dataset,\n", + " compute_metrics=compute_metrics,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f35520bb-b6ca-4996-b87e-2fbfdcfc0dff", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.train()" + ] + }, { "cell_type": "markdown", "id": "6ac841f6-c65e-4ebf-8c42-3030e2f92cb0", @@ -319,7 +428,7 @@ "id": "9204b6dc-8f6e-407e-8c68-a036a6a5b7c9", "metadata": {}, "source": [ - "### Training our Model" + "### Training our ModelFinetuning our Model via Vertex AI Training API" ] }, { @@ -603,7 +712,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/notebooks/GenAI/AWS_GenAI_Jumpstart.ipynb b/notebooks/GenAI/AWS_GenAI_Jumpstart.ipynb index d477dc8..0b1eaa3 100644 --- a/notebooks/GenAI/AWS_GenAI_Jumpstart.ipynb +++ b/notebooks/GenAI/AWS_GenAI_Jumpstart.ipynb @@ -44,20 +44,14 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "6cf1429a-314e-49b6-a4f7-16a3e52319af", "metadata": { "tags": [] }, "outputs": [], "source": [ - "(\n", - " model_id,\n", - " model_version,\n", - ") = (\n", - " \"meta-textgeneration-llama-2-7b-f\",\n", - " \"*\",\n", - ")" + "model_id, model_version = \"meta-textgeneration-llama-2-13b-f\", \"2.*\"" ] }, { @@ -79,7 +73,7 @@ "source": [ "from sagemaker.jumpstart.model import JumpStartModel\n", "\n", - "model = JumpStartModel(model_id=model_id)\n", + "model = JumpStartModel(model_id=model_id, model_version=model_version)\n", "predictor = model.deploy()\n" ] }, @@ -216,7 +210,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": 
"3.10.12" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/notebooks/GenAI/Pubmed_RAG_chatbot.ipynb b/notebooks/GenAI/Pubmed_RAG_chatbot.ipynb index 09aa19a..c53b053 100644 --- a/notebooks/GenAI/Pubmed_RAG_chatbot.ipynb +++ b/notebooks/GenAI/Pubmed_RAG_chatbot.ipynb @@ -39,20 +39,14 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "6b51bf71-d2e5-4afc-8569-338767b43b9c", "metadata": { "tags": [] }, "outputs": [], "source": [ - "(\n", - " model_id,\n", - " model_version,\n", - ") = (\n", - " \"meta-textgeneration-llama-2-7b-f\",\n", - " \"*\",\n", - ")" + "model_id, model_version = \"meta-textgeneration-llama-2-13b-f\", \"2.*\"" ] }, { @@ -65,49 +59,17 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "bf27747d-443f-47e7-9d2c-a8f5c5c6f3b8", "metadata": { "tags": [] }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages/pandas/core/computation/expressions.py:21: UserWarning: Pandas requires version '2.8.0' or newer of 'numexpr' (version '2.7.3' currently installed).\n", - " from pandas.core.computation.check import NUMEXPR_INSTALLED\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml\n", - "sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "For forward compatibility, pin to model_version='2.*' in your JumpStartModel or JumpStartEstimator definitions. Note that major version upgrades may have different EULA acceptance terms and input/output signatures.\n", - "For forward compatibility, pin to model_version='2.*' in your JumpStartModel or JumpStartEstimator definitions. 
Note that major version upgrades may have different EULA acceptance terms and input/output signatures.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "-----------------!" - ] - } - ], + "outputs": [], "source": [ "from sagemaker.jumpstart.model import JumpStartModel\n", "\n", - "model = JumpStartModel(model_id=model_id)\n", - "predictor = model.deploy()" + "model = JumpStartModel(model_id=model_id, model_version=model_version)\n", + "predictor = model.deploy()\n" ] }, { @@ -120,7 +82,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "1ad71f0d-3be5-4b03-9c1c-eb4585721fc8", "metadata": { "tags": [] @@ -226,35 +188,26 @@ "id": "93a8595a-767f-4cad-9273-62d8e2cf60d1", "metadata": {}, "source": [ - "We only want the metadata of the first 100 files." + "We only want the metadata of the first 50 files." ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "c26b0f29-2b07-43a6-800d-4aa5e957fe52", "metadata": { "tags": [] }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages/pandas/core/computation/expressions.py:21: UserWarning: Pandas requires version '2.8.0' or newer of 'numexpr' (version '2.7.3' currently installed).\n", - " from pandas.core.computation.check import NUMEXPR_INSTALLED\n" - ] - } - ], + "outputs": [], "source": [ "#import the file as a dataframe\n", "import pandas as pd\n", "import os\n", "df = pd.read_csv('oa_comm.filelist.csv')\n", - "#first 100 files\n", - "first_100=df[0:101]\n", + "#first 50 files\n", + "first_50=df[0:50]\n", "#save new metadata\n", - "first_100.to_csv('oa_comm.filelist_100.csv', index=False)" + "first_50.to_csv('oa_comm.filelist_50.csv', index=False)" ] }, { @@ -262,7 +215,7 @@ "id": "abd1ae93-450e-4c79-83cc-ea46a1b507c1", "metadata": {}, "source": [ - "Lets look at our metadata! 
We can see that the bucket path to the files are under the **Key** column this is what we will use to loop through the PMC bucket and copy the first 100 files to our bucket." + "Let's look at our metadata! We can see that the bucket path to the files is under the **Key** column; this is what we will use to loop through the PMC bucket and copy the first 50 files to our bucket." ] }, { @@ -272,7 +225,7 @@ "metadata": {}, "outputs": [], "source": [ - "first_100" + "first_50" ] }, { @@ -284,7 +237,7 @@ "source": [ "import os\n", "#gather path to files in bucket\n", - "for i in first_100['Key']:\n", + "for i in first_50['Key']:\n", " os.system(f'aws s3 cp s3://pmc-oa-opendata/{i} s3://{bucket}/docs/ --sse')" ] }, @@ -303,7 +256,7 @@ "metadata": {}, "outputs": [], "source": [ - "!aws s3 cp oa_comm.filelist_100.csv s3://{bucket}/docs/" + "!aws s3 cp oa_comm.filelist_50.csv s3://{bucket}/docs/" ] }, { @@ -381,12 +334,12 @@ }, "source": [ "```python\n", - "from langchain.retrievers import PubMedRetriever\n", + "from langchain_community.retrievers import PubMedRetriever\n", "from langchain.retrievers import AmazonKendraRetriever\n", - "from langchain.llms import SagemakerEndpoint\n", + "from langchain_community.llms import SagemakerEndpoint\n", + "from langchain_community.llms.sagemaker_endpoint import LLMContentHandler\n", "from langchain.chains import ConversationalRetrievalChain\n", "from langchain.prompts import PromptTemplate\n", - "from langchain.llms.sagemaker_endpoint import LLMContentHandler\n", "import sys\n", "import json\n", "import os\n", @@ -739,23 +692,12 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "ba97df23-6893-438d-8a67-cb7dbf83e407", "metadata": { "tags": [] }, - "outputs": [ - { - "data": { - "text/plain": [ - "'meta-textgeneration-llama-2-7b-f-2023-11-21-20-18-40-341'" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "#retreive our endpoint 
id\n", "endpoint_id" @@ -789,7 +731,7 @@ "id": "80c8fb4b-e74f-4e8d-892b-0f913eff747d", "metadata": {}, "source": [ - "![PubMed Chatbot Results](../../../docs/images/PubMed_chatbot_results.png)" + "![PubMed Chatbot Results](../../docs/images/PubMed_chatbot_results.png)" ] }, { @@ -810,22 +752,10 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "c307bb17-757a-4579-a0d8-698eb1bb3f2e", "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'endpoint' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[2], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m#Delete model and endpoint\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;66;03m#model.delete()\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[43mendpoint\u001b[49m\u001b[38;5;241m.\u001b[39mdelete()\n", - "\u001b[0;31mNameError\u001b[0m: name 'endpoint' is not defined" - ] - } - ], + "outputs": [], "source": [ "#Delete model and endpoint\n", "model.delete()\n", diff --git a/notebooks/GenAI/example_scripts/langchain_chat_llama_2_zeroshot.py b/notebooks/GenAI/example_scripts/langchain_chat_llama_2_zeroshot.py index 35725d6..a81830d 100644 --- a/notebooks/GenAI/example_scripts/langchain_chat_llama_2_zeroshot.py +++ b/notebooks/GenAI/example_scripts/langchain_chat_llama_2_zeroshot.py @@ -1,12 +1,11 @@ -from langchain.retrievers import PubMedRetriever +from langchain_community.retrievers import PubMedRetriever from langchain.chains import ConversationalRetrievalChain from langchain.prompts import PromptTemplate -#from langchain import SagemakerEndpoint -from langchain.llms.sagemaker_endpoint import LLMContentHandler +from langchain_community.llms import SagemakerEndpoint +from langchain_community.llms.sagemaker_endpoint import LLMContentHandler import 
sys import json import os -from langchain.llms import SagemakerEndpoint class bcolors: @@ -24,7 +23,6 @@ class bcolors: def build_chain(): region = os.environ["AWS_REGION"] - #kendra_index_id = os.environ["KENDRA_INDEX_ID"] endpoint_name = os.environ["LLAMA_2_ENDPOINT"] class ContentHandler(LLMContentHandler): @@ -58,7 +56,6 @@ def transform_output(self, output: bytes) -> str: content_handler=content_handler, ) - #retriever = AmazonKendraRetriever(index_id=kendra_index_id,region_name=region) retriever= PubMedRetriever() prompt_template = """