diff --git a/swarms/models/popular_llms.py b/swarms/models/popular_llms.py
index b600b4c6..fd9a0192 100644
--- a/swarms/models/popular_llms.py
+++ b/swarms/models/popular_llms.py
@@ -1,17 +1,21 @@
-from langchain_community.chat_models.azure_openai import (
+from langchain_openai.chat_models.azure import (
     AzureChatOpenAI,
 )
-from langchain_community.chat_models.openai import (
+from langchain_openai.chat_models import (
     ChatOpenAI as OpenAIChat,
 )
-from langchain.llms.anthropic import Anthropic
-from langchain.llms.cohere import Cohere
-from langchain.llms.mosaicml import MosaicML
-from langchain.llms.openai import OpenAI  # , OpenAIChat, AzureOpenAI
 from langchain_community.llms.octoai_endpoint import OctoAIEndpoint
-from langchain.llms.replicate import Replicate
-from langchain_community.llms.fireworks import Fireworks  # noqa: F401
-
+from pydantic import model_validator
+from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.sampling_params import SamplingParams
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.utils import random_uuid
+from langchain_community.llms import Anthropic, Cohere, MosaicML, OpenAI, Replicate
+from langchain_fireworks import Fireworks
+from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult, GenerationChunk
+from langchain.schema.messages import AIMessage, AIMessageChunk, BaseMessage, ChatMessage, HumanMessage, SystemMessage
+from langchain.callbacks.manager import AsyncCallbackManagerForLLMRun
+from typing import Any, AsyncIterator, Dict, List, Optional
 
 
 class Anthropic(Anthropic):
     def __call__(self, *args, **kwargs):
@@ -61,6 +65,19 @@ class AzureOpenAILLM(AzureChatOpenAI):
         return self.invoke(*args, **kwargs)
 
 
+# class OpenAIChatLLM(OpenAIChat):
+#     def __init__(self, *args, **kwargs):
+#         super().__init__(*args, **kwargs)
+
+#     def __call__(self, *args, **kwargs):
+#         out = self.invoke(*args, **kwargs)
+#         return out.content.strip()
+
+#     def run(self, *args, **kwargs):
+#         out = self.invoke(*args, **kwargs)
+#         return out.content.strip()
+
+
 class OpenAIChatLLM(OpenAIChat):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -73,6 +90,115 @@ class OpenAIChatLLM(OpenAIChat):
         out = self.invoke(*args, **kwargs)
         return out.content.strip()
+    # @model_validator(mode='after')
+    # def validate_environment(cls, values: Dict) -> Dict:
+    #     """Validate that python package exists in environment."""
+    #     from vllm import AsyncEngineArgs
+
+    #     values["client"] = AsyncLLMEngine.from_engine_args(
+    #         engine_args=AsyncEngineArgs(
+    #             model=values["model"],
+    #             trust_remote_code=True,
+    #             download_dir=values["download_dir"],
+    #             max_model_len=values["vllm_kwargs"]["max_model_len"],
+    #             seed=values["vllm_kwargs"]["seed"],
+    #         ),
+    #     )
+
+    #     return values
+
+    async def _agenerate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        """Run the LLM on the given prompt and input."""
+
+        # build sampling parameters
+        params = {**self._default_params, **kwargs, "stop": stop}
+        sampling_params = SamplingParams(**params)
+        # call the model
+        client = self.client  # type: AsyncLLMEngine
+
+        # generations: List[ChatGeneration] = []
+        for prompt in prompts:
+            output: RequestOutput
+            async for output in client.generate(
+                prompt=prompt, sampling_params=sampling_params, request_id=random_uuid()
+            ):
+                text = output.outputs[0].text
+                output: CompletionOutput = output.outputs[0]
+                # generation_info = output.__dict__
+                # generations.append(
+                #     ChatGenerationChunk(
+                #         message=AIMessage(content=text),
+                #         generation_info=generation_info,
+                #     )
+                # )
+                if output:
+                    yield ChatGenerationChunk(
+                        message=AIMessageChunk(content=output.outputs[0].text),
+                        generation_info=output.outputs[0].generation_info,
+                    )
+                    text = output.outputs[0].text
+                    # generation_info = output.outputs[0].generation_info
+                    if run_manager:
+                        await run_manager.on_llm_new_token(text, verbose=self.verbose)
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        """Stream text generation asynchronously.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Yields:
+            GenerationChunk: Generated text chunks.
+        """
+        prompt = self._format_messages_as_text(messages)
+
+        # build sampling parameters
+        params = {**self._default_params, **kwargs, "stop": stop}
+        sampling_params = SamplingParams(**params)
+        # call the model
+        client = self.client  # type: AsyncLLMEngine
+
+        async for output in client.generate(prompt, sampling_params):
+            if output:
+                yield ChatGenerationChunk(
+                    message=AIMessageChunk(content=output.outputs[0].text),
+                    generation_info=output.outputs[0].generation_info,
+                )
+                text = output.outputs[0].text
+                if run_manager:
+                    await run_manager.on_llm_new_token(text, verbose=self.verbose)
+
+    def _format_message_as_text(self, message: BaseMessage) -> str:
+        if isinstance(message, ChatMessage):
+            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
+        elif isinstance(message, HumanMessage):
+            message_text = f"[INST] {message.content} [/INST]"
+        elif isinstance(message, AIMessage):
+            message_text = f"{message.content}"
+        elif isinstance(message, SystemMessage):
+            message_text = f"<<SYS>> {message.content} <</SYS>>"
+        else:
+            raise ValueError(f"Got unknown type {message}")
+        return message_text
+
+    def _format_messages_as_text(self, messages: List[BaseMessage]) -> str:
+        return "\n".join(
+            [self._format_message_as_text(message) for message in messages]
+        )
+
 
 
 class OctoAIChat(OctoAIEndpoint):
     def __call__(self, *args, **kwargs):