|
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
|
|
|
|
|
from typing import Any, Callable, Mapping
|
|
|
|
|
|
|
|
|
|
import openai
|
|
|
|
|
from langchain_core.pydantic_v1 import (
|
|
|
|
@ -36,14 +36,14 @@ class AzureOpenAI(BaseOpenAI):
|
|
|
|
|
openai = AzureOpenAI(model_name="gpt-3.5-turbo-instruct")
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
azure_endpoint: Union[str, None] = None
|
|
|
|
|
azure_endpoint: str | None = None
|
|
|
|
|
"""Your Azure endpoint, including the resource.
|
|
|
|
|
|
|
|
|
|
Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided.
|
|
|
|
|
|
|
|
|
|
Example: `https://example-resource.azure.openai.com/`
|
|
|
|
|
"""
|
|
|
|
|
deployment_name: Union[str, None] = Field(
|
|
|
|
|
deployment_name: str | None = Field(
|
|
|
|
|
default=None, alias="azure_deployment"
|
|
|
|
|
)
|
|
|
|
|
"""A model deployment.
|
|
|
|
@ -53,11 +53,11 @@ class AzureOpenAI(BaseOpenAI):
|
|
|
|
|
"""
|
|
|
|
|
openai_api_version: str = Field(default="", alias="api_version")
|
|
|
|
|
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
|
|
|
|
|
openai_api_key: Optional[SecretStr] = Field(
|
|
|
|
|
openai_api_key: SecretStr | None = Field(
|
|
|
|
|
default=None, alias="api_key"
|
|
|
|
|
)
|
|
|
|
|
"""Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided."""
|
|
|
|
|
azure_ad_token: Optional[SecretStr] = None
|
|
|
|
|
azure_ad_token: SecretStr | None = None
|
|
|
|
|
"""Your Azure Active Directory token.
|
|
|
|
|
|
|
|
|
|
Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.
|
|
|
|
@ -65,7 +65,7 @@ class AzureOpenAI(BaseOpenAI):
|
|
|
|
|
For more:
|
|
|
|
|
https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id.
|
|
|
|
|
""" # noqa: E501
|
|
|
|
|
azure_ad_token_provider: Union[Callable[[], str], None] = None
|
|
|
|
|
azure_ad_token_provider: Callable[[], str] | None = None
|
|
|
|
|
"""A function that returns an Azure Active Directory token.
|
|
|
|
|
|
|
|
|
|
Will be invoked on every request.
|
|
|
|
@ -78,12 +78,12 @@ class AzureOpenAI(BaseOpenAI):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_lc_namespace(cls) -> List[str]:
|
|
|
|
|
def get_lc_namespace(cls) -> list[str]:
|
|
|
|
|
"""Get the namespace of the langchain object."""
|
|
|
|
|
return ["langchain", "llms", "openai"]
|
|
|
|
|
|
|
|
|
|
@root_validator()
|
|
|
|
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
|
|
|
def validate_environment(cls, values: dict) -> dict:
|
|
|
|
|
"""Validate that api key and python package exists in environment."""
|
|
|
|
|
if values["n"] < 1:
|
|
|
|
|
raise ValueError("n must be at least 1.")
|
|
|
|
@ -206,7 +206,7 @@ class AzureOpenAI(BaseOpenAI):
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _invocation_params(self) -> Dict[str, Any]:
|
|
|
|
|
def _invocation_params(self) -> dict[str, Any]:
|
|
|
|
|
openai_params = {"model": self.deployment_name}
|
|
|
|
|
return {**openai_params, **super()._invocation_params}
|
|
|
|
|
|
|
|
|
@ -216,7 +216,7 @@ class AzureOpenAI(BaseOpenAI):
|
|
|
|
|
return "azure"
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def lc_attributes(self) -> Dict[str, Any]:
|
|
|
|
|
def lc_attributes(self) -> dict[str, Any]:
|
|
|
|
|
return {
|
|
|
|
|
"openai_api_type": self.openai_api_type,
|
|
|
|
|
"openai_api_version": self.openai_api_version,
|
|
|
|
|