Skip to content

Commit

Permalink
vertexai[patch]: infer project from creds (#523)
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan authored Oct 1, 2024
2 parents 961fd31 + 4a11a30 commit b8d4a78
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 1 deletion.
1 change: 1 addition & 0 deletions .github/workflows/_integration_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ jobs:
GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
GOOGLE_SEARCH_API_KEY: ${{ secrets.GOOGLE_SEARCH_API_KEY }}
GOOGLE_CSE_ID: ${{ secrets.GOOGLE_CSE_ID }}
GOOGLE_VERTEX_AI_WEB_CREDENTIALS: ${{ secrets.GOOGLE_VERTEX_AI_WEB_CREDENTIALS }}
run: |
make integration_tests
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/_release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ jobs:
GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
GOOGLE_SEARCH_API_KEY: ${{ secrets.GOOGLE_SEARCH_API_KEY }}
GOOGLE_CSE_ID: ${{ secrets.GOOGLE_CSE_ID }}
GOOGLE_VERTEX_AI_WEB_CREDENTIALS: ${{ secrets.GOOGLE_VERTEX_AI_WEB_CREDENTIALS }}
run: make integration_tests
working-directory: ${{ inputs.working-directory }}

Expand Down
5 changes: 4 additions & 1 deletion libs/vertexai/langchain_google_vertexai/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,10 @@ def validate_params_base(cls, values: dict) -> Any:
@model_validator(mode="after")
def validate_project(self) -> Any:
if self.project is None:
self.project = initializer.global_config.project
if self.credentials and hasattr(self.credentials, "project_id"):
self.project = self.credentials.project_id
else:
self.project = initializer.global_config.project
return self

@property
Expand Down
11 changes: 11 additions & 0 deletions libs/vertexai/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import base64
import json
import os
from typing import List, Optional, cast

import pytest
Expand All @@ -11,6 +12,7 @@
Content,
Part,
)
from google.oauth2 import service_account
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
Expand Down Expand Up @@ -912,3 +914,12 @@ def test_langgraph_example() -> None:
tools=[{"function_declarations": [add_declaration, multiply_declaration]}],
)
assert isinstance(step2, AIMessage)


def test_init_from_credentials_obj() -> None:
credentials_dict = json.loads(os.environ["GOOGLE_VERTEX_AI_WEB_CREDENTIALS"])
credentials = service_account.Credentials.from_service_account_info(
credentials_dict
)
llm = ChatVertexAI(model="gemini-1.5-flash", credentials=credentials)
llm.invoke("how are you")

0 comments on commit b8d4a78

Please sign in to comment.