support vLLM async

pull/570/head
Richard Anthony Hein 9 months ago
parent ccea98d387
commit ce80c085a5

@@ -1,17 +1,21 @@
from langchain_community.chat_models.azure_openai import (
from langchain_openai.chat_models.azure import (
AzureChatOpenAI,
)
from langchain_community.chat_models.openai import (
from langchain_openai.chat_models import (
ChatOpenAI as OpenAIChat,
)
from langchain.llms.anthropic import Anthropic
from langchain.llms.cohere import Cohere
from langchain.llms.mosaicml import MosaicML
from langchain.llms.openai import OpenAI # , OpenAIChat, AzureOpenAI
from langchain_community.llms.octoai_endpoint import OctoAIEndpoint
from langchain.llms.replicate import Replicate
from langchain_community.llms.fireworks import Fireworks # noqa: F401
from pydantic import model_validator
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.utils import random_uuid
from langchain_community.llms import Anthropic, Cohere, MosaicML, OpenAI, Replicate
from langchain_fireworks import Fireworks
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult, GenerationChunk
from langchain.schema.messages import AIMessage, AIMessageChunk, BaseMessage, ChatMessage, HumanMessage, SystemMessage
from langchain.callbacks.manager import AsyncCallbackManagerForLLMRun
from typing import Any, AsyncIterator, Dict, List, Optional
class Anthropic(Anthropic):
    def __call__(self, *args, **kwargs):
@@ -61,6 +65,19 @@ class AzureOpenAILLM(AzureChatOpenAI):
        return self.invoke(*args, **kwargs)
# class OpenAIChatLLM(OpenAIChat):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#     def __call__(self, *args, **kwargs):
#         out = self.invoke(*args, **kwargs)
#         return out.content.strip()
#     def run(self, *args, **kwargs):
#         out = self.invoke(*args, **kwargs)
#         return out.content.strip()
class OpenAIChatLLM(OpenAIChat):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
@@ -73,6 +90,115 @@ class OpenAIChatLLM(OpenAIChat):
        out = self.invoke(*args, **kwargs)
        return out.content.strip()
    # @model_validator(mode='after')
    # def validate_environment(cls, values: Dict) -> Dict:
    #     """Validate that python package exists in environment."""
    #     from vllm import AsyncEngineArgs
    #     values["client"] = AsyncLLMEngine.from_engine_args(
    #         engine_args=AsyncEngineArgs(
    #             model=values["model"],
    #             trust_remote_code=True,
    #             download_dir=values["download_dir"],
    #             max_model_len=values["vllm_kwargs"]["max_model_len"],
    #             seed=values["vllm_kwargs"]["seed"],
    #         ),
    #     )
    #     return values
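For context, a standalone sketch of what the commented-out validator would do once enabled, assuming placeholder values for the model id and download directory (neither value is part of this commit):

from vllm import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# Build the async engine the same way the validator would, with placeholder config values.
engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs(
        model="mistralai/Mistral-7B-Instruct-v0.2",  # placeholder model id
        trust_remote_code=True,
        download_dir="/tmp/models",  # placeholder download directory
        max_model_len=4096,
        seed=0,
    )
)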
    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Run the LLM on the given prompts, yielding chunks asynchronously."""
        # build sampling parameters
        params = {**self._default_params, **kwargs, "stop": stop}
        sampling_params = SamplingParams(**params)
        # call the model
        client = self.client  # type: AsyncLLMEngine
        for prompt in prompts:
            output: RequestOutput
            async for output in client.generate(
                prompt=prompt,
                sampling_params=sampling_params,
                request_id=random_uuid(),
            ):
                # each streamed RequestOutput carries the text generated so far
                completion: CompletionOutput = output.outputs[0]
                text = completion.text
                if text:
                    yield ChatGenerationChunk(
                        message=AIMessageChunk(content=text),
                        generation_info=completion.__dict__,
                    )
                    if run_manager:
                        await run_manager.on_llm_new_token(text, verbose=self.verbose)
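A minimal consumption sketch for the async generator above (illustrative only; `llm` stands in for an instance of the class these methods belong to and is an assumption, not part of this commit):

import asyncio

async def collect(llm, prompt: str) -> str:
    # Drain the async generator; each chunk carries the text generated so far,
    # so keep only the most recent chunk's content.
    latest = ""
    async for chunk in llm._agenerate([prompt]):
        latest = chunk.message.content
    return latest

# text = asyncio.run(collect(llm, "[INST] Say hello [/INST]"))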
    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Stream text generation asynchronously.

        Args:
            messages: The messages to format into a prompt for the model.
            stop: Optional list of stop words to use when generating.

        Yields:
            ChatGenerationChunk: Generated text chunks.
        """
        prompt = self._format_messages_as_text(messages)
        # build sampling parameters
        params = {**self._default_params, **kwargs, "stop": stop}
        sampling_params = SamplingParams(**params)
        # call the model
        client = self.client  # type: AsyncLLMEngine
        async for output in client.generate(
            prompt=prompt, sampling_params=sampling_params, request_id=random_uuid()
        ):
            completion: CompletionOutput = output.outputs[0]
            text = completion.text
            if text:
                yield ChatGenerationChunk(
                    message=AIMessageChunk(content=text),
                    generation_info=completion.__dict__,
                )
                if run_manager:
                    await run_manager.on_llm_new_token(text, verbose=self.verbose)
    def _format_message_as_text(self, message: BaseMessage) -> str:
        if isinstance(message, ChatMessage):
            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
        elif isinstance(message, HumanMessage):
            message_text = f"[INST] {message.content} [/INST]"
        elif isinstance(message, AIMessage):
            message_text = f"{message.content}"
        elif isinstance(message, SystemMessage):
            message_text = f"<<SYS>> {message.content} <</SYS>>"
        else:
            raise ValueError(f"Got unknown type {message}")
        return message_text

    def _format_messages_as_text(self, messages: List[BaseMessage]) -> str:
        return "\n".join(
            [self._format_message_as_text(message) for message in messages]
        )
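For reference, a small sketch of the prompt string these helpers produce for a system + user exchange (the result in the trailing comment follows directly from the branches above; the message contents are made up for illustration):

from langchain.schema.messages import HumanMessage, SystemMessage

messages = [
    SystemMessage(content="You are terse."),
    HumanMessage(content="Hi"),
]
# _format_messages_as_text(messages) would return:
# "<<SYS>> You are terse. <</SYS>>\n[INST] Hi [/INST]"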
class OctoAIChat(OctoAIEndpoint):
    def __call__(self, *args, **kwargs):
