[CHORES][VLLM]

pull/286/head
Kye 1 year ago
parent 154b8cccde
commit 30904e244e

@@ -80,6 +80,7 @@ nav:
- OpenAI: "swarms/models/openai.md"
- Zephyr: "swarms/models/zephyr.md"
- BioGPT: "swarms/models/biogpt.md"
- vLLM: "swarms/models/vllm.md"
- MPT7B: "swarms/models/mpt.md"
- Mistral: "swarms/models/mistral.md"
- MultiModal:

@@ -7,7 +7,7 @@ custom_vllm = vLLM(
    trust_remote_code=True,
    revision="abc123",
    temperature=0.7,
    top_p=0.8
    top_p=0.8,
)
# Generate text with custom configuration
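
Note: the hunk above only adds a trailing comma to the docs example. For context, a minimal sketch of how the configured instance might then be used, assuming the swarms vLLM wrapper exposes a run(task) method (the call itself is not part of this diff):

# Hypothetical continuation of the docs example; assumes vLLM.run(task) -> str.
output = custom_vllm.run("Summarize what the top_p parameter controls in one sentence.")
print(output)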

@@ -12,3 +12,20 @@ weaviate_client = WeaviateClient(
    additional_config=None,  # You can pass additional configuration here
)

weaviate_client.create_collection(
    name="my_collection",
    properties=[
        {"name": "property1", "dataType": ["string"]},
        {"name": "property2", "dataType": ["int"]},
    ],
    vectorizer_config=None,  # Optional vectorizer configuration
)

weaviate_client.add(
    collection_name="my_collection",
    properties={"property1": "value1", "property2": 42},
)

results = weaviate_client.query(
    collection_name="people", query="name:John", limit=5
)
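
Note: a short sketch of how the query results above might be consumed; the return type of weaviate_client.query() is not shown in this hunk, so treating it as an iterable of records is an assumption:

# Hypothetical: iterate over the records returned by the query example above.
for record in results:
    print(record)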

@@ -1,7 +1,8 @@
# from swarms.memory.pinecone import PineconeVector
# from swarms.memory.base import BaseVectorStore
# from swarms.memory.pg import PgVectorVectorStore
from swarms.memory.weaviate import WeaviateClient
try:
    from swarms.memory.weaviate import WeaviateClient
except ImportError:
    pass
from swarms.memory.base_vectordb import VectorDatabase
__all__ = [
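
Note: this hunk makes weaviate an optional dependency, so swarms.memory still imports when the package is missing, but WeaviateClient is then left unbound. A sketch of one way callers can guard for that (the sentinel binding below is illustrative, not what the diff itself does):

# Illustrative optional-dependency guard; not part of this commit.
try:
    from swarms.memory.weaviate import WeaviateClient
except ImportError:
    WeaviateClient = None

if WeaviateClient is None:
    raise RuntimeError("Install the weaviate client to use WeaviateClient")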

@@ -1,130 +0,0 @@
from abc import ABC, abstractmethod
from concurrent import futures
from dataclasses import dataclass
from typing import Optional, Any

from attr import define, field, Factory

from swarms.utils.execute_futures import execute_futures_dict
from griptape.artifacts import TextArtifact


@define
class BaseVectorStore(ABC):
    DEFAULT_QUERY_COUNT = 5

    @dataclass
    class QueryResult:
        id: str
        vector: list[float]
        score: float
        meta: Optional[dict] = None
        namespace: Optional[str] = None

    @dataclass
    class Entry:
        id: str
        vector: list[float]
        meta: Optional[dict] = None
        namespace: Optional[str] = None

    embedding_driver: Any
    futures_executor: futures.Executor = field(
        default=Factory(lambda: futures.ThreadPoolExecutor()),
        kw_only=True,
    )

    def upsert_text_artifacts(
        self,
        artifacts: dict[str, list[TextArtifact]],
        meta: Optional[dict] = None,
        **kwargs,
    ) -> None:
        execute_futures_dict(
            {
                namespace: self.futures_executor.submit(
                    self.upsert_text_artifact,
                    a,
                    namespace,
                    meta,
                    **kwargs,
                )
                for namespace, artifact_list in artifacts.items()
                for a in artifact_list
            }
        )

    def upsert_text_artifact(
        self,
        artifact: TextArtifact,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        if not meta:
            meta = {}

        meta["artifact"] = artifact.to_json()

        if artifact.embedding:
            vector = artifact.embedding
        else:
            vector = artifact.generate_embedding(
                self.embedding_driver
            )

        return self.upsert_vector(
            vector,
            vector_id=artifact.id,
            namespace=namespace,
            meta=meta,
            **kwargs,
        )

    def upsert_text(
        self,
        string: str,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        return self.upsert_vector(
            self.embedding_driver.embed_string(string),
            vector_id=vector_id,
            namespace=namespace,
            meta=meta if meta else {},
            **kwargs,
        )

    @abstractmethod
    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        ...

    @abstractmethod
    def load_entry(
        self, vector_id: str, namespace: Optional[str] = None
    ) -> Entry:
        ...

    @abstractmethod
    def load_entries(
        self, namespace: Optional[str] = None
    ) -> list[Entry]:
        ...

    @abstractmethod
    def query(
        self,
        query: str,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[QueryResult]:
        ...

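Note: the hunk above deletes swarms/memory/base.py entirely. As a reading aid for what that interface required, here is a toy in-memory subclass; the class name and storage strategy are illustrative only and appear nowhere in the repository:

# Illustrative only: a minimal concrete implementation of the removed
# BaseVectorStore interface, showing what its four abstract methods covered.
from attr import define, field, Factory


@define
class InMemoryVectorStore(BaseVectorStore):
    entries: dict = field(default=Factory(dict), kw_only=True)

    def upsert_vector(self, vector, vector_id=None, namespace=None, meta=None, **kwargs):
        vector_id = vector_id or str(len(self.entries))
        self.entries[vector_id] = self.Entry(
            id=vector_id, vector=vector, meta=meta, namespace=namespace
        )
        return vector_id

    def load_entry(self, vector_id, namespace=None):
        return self.entries[vector_id]

    def load_entries(self, namespace=None):
        return [
            entry
            for entry in self.entries.values()
            if namespace is None or entry.namespace == namespace
        ]

    def query(self, query, count=None, namespace=None, include_vectors=False, **kwargs):
        # A real driver would embed `query` with self.embedding_driver and rank
        # stored vectors by similarity; this toy version returns nothing.
        return []
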
@@ -1,12 +1,14 @@
import os
from termcolor import colored
import logging
import os
from typing import Dict, List, Optional
import chromadb
import tiktoken as tiktoken
from chromadb.config import Settings
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
from dotenv import load_dotenv
from termcolor import colored
from swarms.utils.token_count_tiktoken import limit_tokens_from_string
load_dotenv()

@@ -8,7 +8,12 @@ from swarms.models.openai_models import (
    AzureOpenAI,
    OpenAIChat,
) # noqa: E402
from swarms.models.vllm import vLLM # noqa: E402
try:
    from swarms.models.vllm import vLLM  # noqa: E402
except ImportError:
    pass
# from swarms.models.zephyr import Zephyr # noqa: E402
from swarms.models.biogpt import BioGPT # noqa: E402
from swarms.models.huggingface import HuggingfaceLLM # noqa: E402
@@ -59,4 +64,5 @@ __all__ = [
    # "Dalle3",
    # "DistilWhisperModel",
    "GPT4VisionAPI",
    "vLLM",
]
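
Note: with the vLLM import wrapped in try/except and "vLLM" listed in __all__, the name is only bound when the vllm package is installed. A sketch of defensive handling on the consumer side (illustrative, not part of this commit):

# Illustrative guard for downstream code; not part of this commit.
try:
    from swarms.models import vLLM
except ImportError:
    vLLM = None

if vLLM is None:
    print("vllm is not installed; fall back to another model backend")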
