|
|
@ -1,5 +1,7 @@
|
|
|
|
from __future__ import annotations
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
|
|
|
|
import functools
|
|
|
|
import logging
|
|
|
|
import logging
|
|
|
|
import sys
|
|
|
|
import sys
|
|
|
|
from typing import (
|
|
|
|
from typing import (
|
|
|
@ -16,6 +18,7 @@ from typing import (
|
|
|
|
Optional,
|
|
|
|
Optional,
|
|
|
|
Set,
|
|
|
|
Set,
|
|
|
|
Tuple,
|
|
|
|
Tuple,
|
|
|
|
|
|
|
|
Type,
|
|
|
|
Union,
|
|
|
|
Union,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
@ -23,7 +26,7 @@ from langchain.callbacks.manager import (
|
|
|
|
AsyncCallbackManagerForLLMRun,
|
|
|
|
AsyncCallbackManagerForLLMRun,
|
|
|
|
CallbackManagerForLLMRun,
|
|
|
|
CallbackManagerForLLMRun,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
from langchain.llms.base import BaseLLM, create_base_retry_decorator
|
|
|
|
from langchain.llms.base import BaseLLM
|
|
|
|
from langchain.pydantic_v1 import Field, root_validator
|
|
|
|
from langchain.pydantic_v1 import Field, root_validator
|
|
|
|
from langchain.schema import Generation, LLMResult
|
|
|
|
from langchain.schema import Generation, LLMResult
|
|
|
|
from langchain.schema.output import GenerationChunk
|
|
|
|
from langchain.schema.output import GenerationChunk
|
|
|
@ -32,7 +35,17 @@ from langchain.utils import (
|
|
|
|
get_pydantic_field_names,
|
|
|
|
get_pydantic_field_names,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
from langchain.utils.utils import build_extra_kwargs
|
|
|
|
from langchain.utils.utils import build_extra_kwargs
|
|
|
|
|
|
|
|
from tenacity import (
|
|
|
|
|
|
|
|
RetryCallState,
|
|
|
|
|
|
|
|
before_sleep_log,
|
|
|
|
|
|
|
|
retry,
|
|
|
|
|
|
|
|
retry_base,
|
|
|
|
|
|
|
|
retry_if_exception_type,
|
|
|
|
|
|
|
|
stop_after_attempt,
|
|
|
|
|
|
|
|
wait_exponential,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
from importlib.metadata import version
|
|
|
|
from importlib.metadata import version
|
|
|
|
|
|
|
|
|
|
|
@ -41,6 +54,62 @@ from packaging.version import parse
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@functools.lru_cache
|
|
|
|
|
|
|
|
def _log_error_once(msg: str) -> None:
|
|
|
|
|
|
|
|
"""Log an error once."""
|
|
|
|
|
|
|
|
logger.error(msg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_base_retry_decorator(
|
|
|
|
|
|
|
|
error_types: List[Type[BaseException]],
|
|
|
|
|
|
|
|
max_retries: int = 1,
|
|
|
|
|
|
|
|
run_manager: Optional[
|
|
|
|
|
|
|
|
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
|
|
|
|
|
|
|
|
] = None,
|
|
|
|
|
|
|
|
) -> Callable[[Any], Any]:
|
|
|
|
|
|
|
|
"""Create a retry decorator for a given LLM and provided list of error types."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_logging = before_sleep_log(logger, logging.WARNING)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _before_sleep(retry_state: RetryCallState) -> None:
|
|
|
|
|
|
|
|
_logging(retry_state)
|
|
|
|
|
|
|
|
if run_manager:
|
|
|
|
|
|
|
|
if isinstance(run_manager, AsyncCallbackManagerForLLMRun):
|
|
|
|
|
|
|
|
coro = run_manager.on_retry(retry_state)
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
loop = asyncio.get_event_loop()
|
|
|
|
|
|
|
|
if loop.is_running():
|
|
|
|
|
|
|
|
loop.create_task(coro)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
asyncio.run(coro)
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
|
|
_log_error_once(f"Error in on_retry: {e}")
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
run_manager.on_retry(retry_state)
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
min_seconds = 4
|
|
|
|
|
|
|
|
max_seconds = 10
|
|
|
|
|
|
|
|
# Wait 2^x * 1 second between each retry starting with
|
|
|
|
|
|
|
|
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
|
|
|
|
|
|
|
retry_instance: "retry_base" = retry_if_exception_type(
|
|
|
|
|
|
|
|
error_types[0]
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
for error in error_types[1:]:
|
|
|
|
|
|
|
|
retry_instance = retry_instance | retry_if_exception_type(
|
|
|
|
|
|
|
|
error
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return retry(
|
|
|
|
|
|
|
|
reraise=True,
|
|
|
|
|
|
|
|
stop=stop_after_attempt(max_retries),
|
|
|
|
|
|
|
|
wait=wait_exponential(
|
|
|
|
|
|
|
|
multiplier=1, min=min_seconds, max=max_seconds
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
retry=retry_instance,
|
|
|
|
|
|
|
|
before_sleep=_before_sleep,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_openai_v1() -> bool:
|
|
|
|
def is_openai_v1() -> bool:
|
|
|
|
_version = parse(version("openai"))
|
|
|
|
_version = parse(version("openai"))
|
|
|
|
return _version.major >= 1
|
|
|
|
return _version.major >= 1
|
|
|
@ -833,7 +902,7 @@ class OpenAIChat(BaseLLM):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
client: Any #: :meta private:
|
|
|
|
client: Any #: :meta private:
|
|
|
|
model_name: str = "gpt-3.5-turbo-1106"
|
|
|
|
model_name: str = "gpt-4-1106-preview"
|
|
|
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
openai_api_key: Optional[str] = None
|
|
|
|
openai_api_key: Optional[str] = None
|
|
|
|
openai_api_base: Optional[str] = None
|
|
|
|
openai_api_base: Optional[str] = None
|
|
|
|