[CHORES][VLLM]

pull/286/head
Kye 1 year ago
parent 154b8cccde
commit 30904e244e

@ -80,6 +80,7 @@ nav:
- OpenAI: "swarms/models/openai.md" - OpenAI: "swarms/models/openai.md"
- Zephyr: "swarms/models/zephyr.md" - Zephyr: "swarms/models/zephyr.md"
- BioGPT: "swarms/models/biogpt.md" - BioGPT: "swarms/models/biogpt.md"
- vLLM: "swarms/models/vllm.md"
- MPT7B: "swarms/models/mpt.md" - MPT7B: "swarms/models/mpt.md"
- Mistral: "swarms/models/mistral.md" - Mistral: "swarms/models/mistral.md"
- MultiModal: - MultiModal:

@ -7,7 +7,7 @@ custom_vllm = vLLM(
trust_remote_code=True, trust_remote_code=True,
revision="abc123", revision="abc123",
temperature=0.7, temperature=0.7,
top_p=0.8 top_p=0.8,
) )
# Generate text with custom configuration # Generate text with custom configuration

@ -12,3 +12,20 @@ weaviate_client = WeaviateClient(
additional_config=None, # You can pass additional configuration here additional_config=None, # You can pass additional configuration here
) )
# Usage example for the project's WeaviateClient wrapper.
# NOTE(review): presumably `weaviate_client` was constructed above with
# WeaviateClient(...); confirm against the surrounding document.

# Create a collection (Weaviate class) with two typed properties.
weaviate_client.create_collection(
    name="my_collection",
    properties=[
        {"name": "property1", "dataType": ["string"]},
        {"name": "property2", "dataType": ["int"]},
    ],
    vectorizer_config=None,  # Optional vectorizer configuration
)
# Insert one object into the collection just created.
weaviate_client.add(
    collection_name="my_collection",
    properties={"property1": "value1", "property2": 42},
)
# Query a collection; returns at most `limit` matching objects.
# NOTE(review): this queries "people", not the "my_collection" created
# above — verify the example is intentional.
results = weaviate_client.query(
    collection_name="people", query="name:John", limit=5
)

@ -1,7 +1,8 @@
# from swarms.memory.pinecone import PineconeVector try:
# from swarms.memory.base import BaseVectorStore from swarms.memory.weaviate import WeaviateClient
# from swarms.memory.pg import PgVectorVectorStore except ImportError:
from swarms.memory.weaviate import WeaviateClient pass
from swarms.memory.base_vectordb import VectorDatabase from swarms.memory.base_vectordb import VectorDatabase
__all__ = [ __all__ = [

@ -1,130 +0,0 @@
from abc import ABC, abstractmethod
from concurrent import futures
from dataclasses import dataclass
from typing import Optional, Any
from attr import define, field, Factory
from swarms.utils.execute_futures import execute_futures_dict
from griptape.artifacts import TextArtifact
@define
class BaseVectorStore(ABC):
    """Abstract base class for vector-store drivers.

    Concrete backends implement the abstract persistence/search methods
    (`upsert_vector`, `load_entry`, `load_entries`, `query`); the concrete
    helpers defined here turn text into vectors via `embedding_driver` and
    fan upserts out over a thread pool.
    """

    # Default number of results a backend's query() is expected to return
    # when the caller does not pass an explicit count.
    DEFAULT_QUERY_COUNT = 5

    @dataclass
    class QueryResult:
        # One search hit: the stored vector plus its score.
        id: str
        vector: list[float]
        score: float  # similarity/distance; semantics are backend-defined
        meta: Optional[dict] = None
        namespace: Optional[str] = None

    @dataclass
    class Entry:
        # One stored record, as returned by load_entry()/load_entries().
        id: str
        vector: list[float]
        meta: Optional[dict] = None
        namespace: Optional[str] = None

    # Embedding backend. From usage below it must expose embed_string()
    # and be accepted by TextArtifact.generate_embedding() — presumably a
    # griptape-style embedding driver; TODO confirm against concrete drivers.
    embedding_driver: Any
    # Executor used to parallelize upsert_text_artifacts(); keyword-only
    # attrs field defaulting to a fresh ThreadPoolExecutor per instance.
    futures_executor: futures.Executor = field(
        default=Factory(lambda: futures.ThreadPoolExecutor()),
        kw_only=True,
    )

    def upsert_text_artifacts(
        self,
        artifacts: dict[str, list[TextArtifact]],
        meta: Optional[dict] = None,
        **kwargs,
    ) -> None:
        """Upsert every artifact in `artifacts` concurrently.

        `artifacts` maps namespace -> list of TextArtifacts; each artifact
        is submitted to the executor as an upsert_text_artifact() call and
        the futures are awaited via execute_futures_dict().

        NOTE(review): the futures dict is keyed by namespace, so when a
        namespace has more than one artifact, earlier submissions for that
        namespace are dropped from the dict before being awaited — only the
        last future per namespace is tracked. Confirm this is intentional.
        """
        execute_futures_dict(
            {
                namespace: self.futures_executor.submit(
                    self.upsert_text_artifact,
                    a,
                    namespace,
                    meta,
                    **kwargs,
                )
                for namespace, artifact_list in artifacts.items()
                for a in artifact_list
            }
        )

    def upsert_text_artifact(
        self,
        artifact: TextArtifact,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Embed one TextArtifact and store it; return the new vector id.

        The artifact's JSON form is recorded under meta["artifact"]. An
        existing artifact.embedding is reused; otherwise an embedding is
        generated with self.embedding_driver.

        NOTE(review): when the caller passes a non-empty `meta` dict it is
        mutated in place (the "artifact" key is added) — confirm callers
        tolerate this.
        """
        if not meta:
            meta = {}
        meta["artifact"] = artifact.to_json()
        if artifact.embedding:
            vector = artifact.embedding
        else:
            vector = artifact.generate_embedding(
                self.embedding_driver
            )
        return self.upsert_vector(
            vector,
            vector_id=artifact.id,
            namespace=namespace,
            meta=meta,
            **kwargs,
        )

    def upsert_text(
        self,
        string: str,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Embed a raw string and store it; return the vector id."""
        return self.upsert_vector(
            self.embedding_driver.embed_string(string),
            vector_id=vector_id,
            namespace=namespace,
            meta=meta if meta else {},
            **kwargs,
        )

    @abstractmethod
    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Store (or update) a raw vector; return its id."""
        ...

    @abstractmethod
    def load_entry(
        self, vector_id: str, namespace: Optional[str] = None
    ) -> Entry:
        """Load a single stored entry by id."""
        ...

    @abstractmethod
    def load_entries(
        self, namespace: Optional[str] = None
    ) -> list[Entry]:
        """Load all entries, optionally restricted to one namespace."""
        ...

    @abstractmethod
    def query(
        self,
        query: str,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[QueryResult]:
        """Search with a text query; return up to `count` results."""
        ...

@ -1,12 +1,14 @@
import os
from termcolor import colored
import logging import logging
import os
from typing import Dict, List, Optional from typing import Dict, List, Optional
import chromadb import chromadb
import tiktoken as tiktoken import tiktoken as tiktoken
from chromadb.config import Settings from chromadb.config import Settings
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
from dotenv import load_dotenv from dotenv import load_dotenv
from termcolor import colored
from swarms.utils.token_count_tiktoken import limit_tokens_from_string from swarms.utils.token_count_tiktoken import limit_tokens_from_string
load_dotenv() load_dotenv()

@ -8,7 +8,12 @@ from swarms.models.openai_models import (
AzureOpenAI, AzureOpenAI,
OpenAIChat, OpenAIChat,
) # noqa: E402 ) # noqa: E402
from swarms.models.vllm import vLLM # noqa: E402
try:
from swarms.models.vllm import vLLM # noqa: E402
except ImportError:
pass
# from swarms.models.zephyr import Zephyr # noqa: E402 # from swarms.models.zephyr import Zephyr # noqa: E402
from swarms.models.biogpt import BioGPT # noqa: E402 from swarms.models.biogpt import BioGPT # noqa: E402
from swarms.models.huggingface import HuggingfaceLLM # noqa: E402 from swarms.models.huggingface import HuggingfaceLLM # noqa: E402
@ -59,4 +64,5 @@ __all__ = [
# "Dalle3", # "Dalle3",
# "DistilWhisperModel", # "DistilWhisperModel",
"GPT4VisionAPI", "GPT4VisionAPI",
"vLLM",
] ]

Loading…
Cancel
Save