diff --git a/swarms/agents/models/openai.py b/swarms/agents/models/openai.py index 6958cd1a..80bcbcbe 100644 --- a/swarms/agents/models/openai.py +++ b/swarms/agents/models/openai.py @@ -1,4 +1,9 @@ -from simpleaichat import AIChat +#kye +#aug 8, 11:51 + +from simpleaichat import AIChat, AsyncAIChat +import asyncio + class OpenAI: def __init__(self, @@ -10,12 +15,27 @@ class OpenAI: save_messages=True): self.api_key = api_key or self._fetch_api_key() self.system = system or "You are a helpful assistant" - self.ai = AIChat(api_key=self.api_key, - system=self.system, - console=console, - model=model, - params=params, - save_messages=save_messages) + + try: + + self.ai = AIChat(api_key=self.api_key, + system=self.system, + console=console, + model=model, + params=params, + save_messages=save_messages) + + self.async_ai = AsyncAIChat( + api_key=self.api_key, + system=self.system, + console=console, + model=model, + params=params, + save_messages=save_messages + ) + + except Exception as error: + raise ValueError(f"Failed to initialize the chat with error: {error}, check inputs and input types") def __call__(self, message, **kwargs): try: @@ -28,6 +48,41 @@ class OpenAI: return self.ai(message, **kwargs) except Exception as error: print(f"Error in OpenAI, {error}") + + async def generate_async(self, message, **kwargs): + try: + return await self.async_ai(message, **kwargs) + except Exception as error: + raise Exception(f"Error in asynchronous OpenAI Call, {error}") + + def initialize_chat(self, ids): + for id in ids: + try: + self.async_ai.new_session(api_key=self.api_key, id=id) + except Exception as error: + raise ValueError(f"Failed to initialize session for ID {id} with error: {error}") + + async def ask_multiple(self, ids, question_template): + try: + self.initialize_chat(ids) + tasks = [self.async_ai(question_template.format(id=id), id=id) for id in ids] + return await asyncio.gather(*tasks) + except Exception as error: + raise Exception(f"Error in ask_multiple: method: {error}") + + async def stream_multiple(self, ids, question_template): + try: + self.initialize_chat(ids) + + async def stream_id(id): + async for chunk in await self.async_ai.stream(question_template.format(id=id), id=id): + response = chunk["response"] + return response + + tasks = [stream_id(id) for id in ids] + return await asyncio.gather(*tasks) + except Exception as error: + raise Exception(f"Error in stream_multiple method: {error}") def fetch_api_key(self): pass @@ -39,3 +94,7 @@ class OpenAI: #response = chat.generate("Hello World") #print(response) +#async +# async_responses = asyncio.run(chat.ask_multiple(['id1', 'id2'], "How is {id}")) +# print(async_responses) +# \ No newline at end of file