black reformat

pull/385/head
evelynmitchell 1 year ago
parent 3c7e2b1df0
commit 9112017234

@@ -1,6 +1,7 @@
 """Example of using the swarms package to run a workflow."""
+
 from swarms import Agent, OpenAIChat
-## Initialize the workflow
+# Initialize the workflow
 agent = Agent(
     llm=OpenAIChat(),
     max_loops="auto",
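This commit runs black across the repository; every hunk below is a pure reformatting with no behavioral change. To reproduce a hunk locally, black also exposes a Python API. A minimal sketch, assuming black is installed (pip install black); exact output depends on the black version the project pins:

import black

src = (
    '"""Example of using the swarms package to run a workflow."""\n'
    "from swarms import Agent, OpenAIChat\n"
)
# Mode() uses black's default 88-character line length.
print(black.format_str(src, mode=black.Mode()))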

@@ -25,7 +25,9 @@ from langchain_community.callbacks.manager import (
 )
 from langchain_community.llms.base import LLM
 from pydantic import Field, SecretStr, root_validator
-from langchain_community.schema.language_model import BaseLanguageModel
+from langchain_community.schema.language_model import (
+    BaseLanguageModel,
+)
 from langchain_community.schema.output import GenerationChunk
 from langchain_community.schema.prompt import PromptValue
 from langchain_community.utils import (
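This hunk shows the commit's most common pattern: an import line longer than black's 88-character limit is wrapped in parentheses, one name per line, with a trailing comma. Both spellings bind the same object, as this self-contained check illustrates (os.path is used here only for the demonstration):

import os.path
from os.path import join
from os.path import (
    join,
)

assert join is os.path.join  # one-line and parenthesized forms are identical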

@@ -88,9 +88,10 @@ class Kosmos(BaseMultiModalModel):
             skip_special_tokens=True,
         )[0]
-        processed_text, entities = (
-            self.processor.post_process_generation(generated_texts)
-        )
+        (
+            processed_text,
+            entities,
+        ) = self.processor.post_process_generation(generated_texts)
 
         return processed_text, entities
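Here black moves the parentheses to the left-hand side: when an unpacking assignment is too long, the target tuple itself is exploded one name per line. The two layouts are equivalent at runtime; a minimal sketch with a hypothetical stand-in for post_process_generation:

def post_process(text):
    # Hypothetical stand-in returning a 2-tuple, like the processor call above.
    return text.upper(), ["entity"]

# Old layout: plain unpacking assignment.
processed_text, entities = post_process("generated")

# New layout: parenthesized, exploded assignment targets.
(
    processed_text,
    entities,
) = post_process("generated")

assert processed_text == "GENERATED"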

@@ -115,13 +115,14 @@ class MedicalSAM:
         if len(box_torch.shape) == 2:
             box_torch = box_torch[:, None, :]
-        sparse_embeddings, dense_embeddings = (
-            self.model.prompt_encoder(
-                points=None,
-                boxes=box_torch,
-                masks=None,
-            )
-        )
+        (
+            sparse_embeddings,
+            dense_embeddings,
+        ) = self.model.prompt_encoder(
+            points=None,
+            boxes=box_torch,
+            masks=None,
+        )
 
         low_res_logits, _ = self.model.mask_decoder(
             image_embeddings=img,

@@ -208,9 +208,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     """Maximum number of texts to embed in each batch"""
     max_retries: int = 6
     """Maximum number of retries to make when generating."""
-    request_timeout: Optional[Union[float, Tuple[float, float]]] = (
-        None
-    )
+    request_timeout: Optional[
+        Union[float, Tuple[float, float]]
+    ] = None
     """Timeout in seconds for the OpenAPI request."""
     headers: Any = None
     tiktoken_model_name: Optional[str] = None
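The request_timeout field keeps its type and default; black now breaks inside the Optional[...] subscript instead of parenthesizing the value. Annotations are inert at runtime either way, which a standalone snippet confirms:

from typing import Optional, Tuple, Union

request_timeout: Optional[
    Union[float, Tuple[float, float]]
] = None

assert request_timeout is None  # default unchanged by the layout
print(__annotations__["request_timeout"])  # the recorded annotation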

@@ -244,9 +244,9 @@ class BaseOpenAI(BaseLLM):
         attributes["openai_api_base"] = self.openai_api_base
 
         if self.openai_organization != "":
-            attributes["openai_organization"] = (
-                self.openai_organization
-            )
+            attributes[
+                "openai_organization"
+            ] = self.openai_organization
 
         if self.openai_proxy != "":
             attributes["openai_proxy"] = self.openai_proxy
@@ -287,9 +287,9 @@ class BaseOpenAI(BaseLLM):
     openai_proxy: Optional[str] = None
     batch_size: int = 20
     """Batch size to use when passing multiple documents to generate."""
-    request_timeout: Optional[Union[float, Tuple[float, float]]] = (
-        None
-    )
+    request_timeout: Optional[
+        Union[float, Tuple[float, float]]
+    ] = None
     """Timeout for requests to OpenAI completion API. Default is 600 seconds."""
     logit_bias: Optional[Dict[str, float]] = Field(
         default_factory=dict

@@ -3,7 +3,9 @@ from __future__ import annotations
 import logging
 from typing import Any, Callable, Dict, List, Optional
 
-from langchain_community.callbacks.manager import CallbackManagerForLLMRun
+from langchain_community.callbacks.manager import (
+    CallbackManagerForLLMRun,
+)
 from langchain_community.llms import BaseLLM
 from langchain_community.pydantic_v1 import BaseModel, root_validator
 from langchain_community.schema import Generation, LLMResult

@@ -86,7 +86,9 @@ class BaseMessage(Serializable):
         return True
 
     def __add__(self, other: Any) -> ChatPromptTemplate:
-        from langchain_community.prompts.chat import ChatPromptTemplate
+        from langchain_community.prompts.chat import (
+            ChatPromptTemplate,
+        )
 
         prompt = ChatPromptTemplate(messages=[self])
         return prompt + other
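The wrapped import sits inside __add__, a common way to defer an import and dodge a circular dependency; the method lets a message be combined with + into a prompt. A minimal sketch of that operator pattern with illustrative classes (not langchain's actual ones):

class Prompt:
    def __init__(self, messages):
        self.messages = list(messages)

    def __add__(self, other):
        # Chaining returns a new Prompt with the extra message appended.
        return Prompt(self.messages + [other])


class Message:
    def __init__(self, text):
        self.text = text

    def __add__(self, other):
        return Prompt([self]) + other


combined = Message("hi") + Message("there")
print([m.text for m in combined.messages])  # ['hi', 'there']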

@@ -62,6 +62,8 @@ def worker_tools_sop_promp(name: str, memory: str):
     [{memory}]
 
     Human: Determine which next command to use, and respond using the format specified above:
-    """.format(name=name, memory=memory, time=time)
+    """.format(
+        name=name, memory=memory, time=time
+    )
 
     return str(out)
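The unusual-looking hunk above is .format() called directly on a triple-quoted literal; black wraps only the argument list. A runnable sketch of the same construct (names here are illustrative):

name, memory = "worker-1", "no prior memory"

out = """
Name: {name}
Memory: [{memory}]
""".format(
    name=name, memory=memory
)
print(out)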

@@ -36,9 +36,9 @@ class ConcurrentWorkflow(BaseStructure):
     max_loops: int = 1
     max_workers: int = 5
     autosave: bool = False
-    saved_state_filepath: Optional[str] = (
-        "runs/concurrent_workflow.json"
-    )
+    saved_state_filepath: Optional[
+        str
+    ] = "runs/concurrent_workflow.json"
     print_results: bool = False
     return_results: bool = False
     use_processes: bool = False

@@ -317,9 +317,9 @@ class MultiAgentCollaboration:
         """Tracks and reports the performance of each agent"""
         performance_data = {}
         for agent in self.agents:
-            performance_data[agent.name] = (
-                agent.get_performance_metrics()
-            )
+            performance_data[
+                agent.name
+            ] = agent.get_performance_metrics()
         return performance_data
 
     def set_interaction_rules(self, rules):

@@ -42,9 +42,9 @@ class SequentialWorkflow:
     task_pool: List[Task] = field(default_factory=list)
     max_loops: int = 1
     autosave: bool = False
-    saved_state_filepath: Optional[str] = (
-        "sequential_workflow_state.json"
-    )
+    saved_state_filepath: Optional[
+        str
+    ] = "sequential_workflow_state.json"
     restore_state_filepath: Optional[str] = None
     dashboard: bool = False

@@ -84,7 +84,9 @@ class RedisSwarmRegistry(AbstractSwarm):
         query = f"""
         {match_query}
         CREATE (a)-[r:joined]->(b) RETURN r
-        """.replace("\n", "")
+        """.replace(
+            "\n", ""
+        )
 
         self.redis_graph.query(query)
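Likewise, the .replace chained onto the f-string flattens the Cypher query to a single line before it is sent; black only explodes the call's arguments. A self-contained sketch (no Redis connection needed):

match_query = "MATCH (a:Agent), (b:Agent) "

query = f"""
{match_query}
CREATE (a)-[r:joined]->(b) RETURN r
""".replace(
    "\n", ""
)
print(query)  # the query collapsed onto one line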

@@ -902,9 +902,9 @@ def tool(
             coroutine = ainvoke_wrapper
             func = invoke_wrapper
-            schema: Optional[Type[BaseModel]] = (
-                runnable.input_schema
-            )
+            schema: Optional[
+                Type[BaseModel]
+            ] = runnable.input_schema
             description = repr(runnable)
         elif inspect.iscoroutinefunction(dec_func):
             coroutine = dec_func

@@ -8,4 +8,3 @@ def load_environment():
     api_key = os.environ.get("OPENAI_API_KEY")
     return api_key, os.environ
-
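The final hunk only drops a trailing line: black trims blank lines at end of file down to a single newline. A quick check, again assuming black is installed:

import black

src = "x = 1\n\n\n"
# Trailing blank lines are collapsed to one final newline.
print(repr(black.format_str(src, mode=black.Mode())))  # 'x = 1\n'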
