commit 7958fb01e1
@@ -0,0 +1,266 @@
import asyncio
import os

import discord
from discord.ext import commands
from dotenv import load_dotenv


class BotCommands(commands.Cog):
    def __init__(self, bot):
        self.bot = bot

    @commands.command()
    async def greet(self, ctx):
        """greets the user."""
        await ctx.send(f"hello, {ctx.author.name}!")

    @commands.command()
    async def help_me(self, ctx):
        """provides a list of commands and their descriptions."""
        help_text = """
        - `!greet`: greets you.
        - `!run [description]`: generates a video based on the given description.
        - `!help_me`: provides this list of commands and their descriptions.
        """
        await ctx.send(help_text)

    @commands.command()
    async def join(self, ctx):
        """joins the voice channel that the user is in."""
        if ctx.author.voice:
            channel = ctx.author.voice.channel
            await channel.connect()
        else:
            await ctx.send("you are not in a voice channel!")

    @commands.command()
    async def leave(self, ctx):
        """leaves the voice channel that the bot is in."""
        if ctx.voice_client:
            await ctx.voice_client.disconnect()
        else:
            await ctx.send("i am not in a voice channel!")

    @commands.command()
    async def listen(self, ctx):
        """starts listening to voice in the voice channel that the bot is in."""
        if ctx.voice_client:
            # create a WaveSink to record the audio (py-cord's voice-receive
            # API; WaveSink takes no path, and start_recording needs a
            # completion callback)
            sink = discord.sinks.WaveSink()

            async def finished(sink, *args):
                # called when recording stops; process sink.audio_data here
                pass

            # start recording
            ctx.voice_client.start_recording(sink, finished)
            await ctx.send("started listening and recording.")
        else:
            await ctx.send("i am not in a voice channel!")

    @commands.command()
    async def generate_image(self, ctx, *, prompt: str = None, imggen=None):
        """generates images based on the provided prompt"""
        await ctx.send(f"generating images for prompt: `{prompt}`...")
        loop = asyncio.get_event_loop()

        # run the (blocking) image-generation callable in the default executor;
        # the original passed invoke's Executor class here, which is not a
        # concurrent.futures executor — None selects the loop's default
        future = loop.run_in_executor(None, imggen, prompt)

        try:
            # wait for the request to complete, with a timeout of 300 seconds
            await asyncio.wait_for(future, timeout=300)
            print("done generating images!")

            # list all files in the save directory
            all_files = [
                os.path.join(root, file)
                for root, _, files in os.walk(os.environ["SAVE_DIRECTORY"])
                for file in files
            ]

            # sort files by their creation time (latest first)
            sorted_files = sorted(all_files, key=os.path.getctime, reverse=True)

            # get the 4 most recent files
            latest_files = sorted_files[:4]
            print(f"sending {len(latest_files)} images to discord...")

            # send all the latest images in a single message
            # storage_service = os.environ["STORAGE_SERVICE"]  # "https://storage.googleapis.com/your-bucket-name/
            # await ctx.send(files=[storage_service.upload(filepath) for filepath in latest_files])

        except asyncio.TimeoutError:
            await ctx.send(
                "the request took too long! it might have been censored or you're out of boosts. please try entering the prompt again."
            )
        except Exception as e:
            await ctx.send(f"an error occurred: {e}")

    @commands.command()
    async def send_text(self, ctx, *, text: str, use_agent: bool = True):
        """sends the provided text to the worker and returns the response"""
        if use_agent:
            # NOTE: assumes an `agent` attribute has been attached to the bot
            response = self.bot.agent.run(text)
        else:
            response = self.bot.llm(text)
        await ctx.send(response)

    @commands.Cog.listener()
    async def on_ready(self):
        print(f"we have logged in as {self.bot.user}")

    @commands.Cog.listener()
    async def on_command_error(self, ctx, error):
        """handles errors that occur while executing commands."""
        if isinstance(error, commands.CommandNotFound):
            await ctx.send("that command does not exist!")
        else:
            await ctx.send(f"an error occurred: {error}")


class Bot:
    def __init__(self, llm, command_prefix="!"):
        load_dotenv()

        intents = discord.Intents.default()
        intents.messages = True
        intents.guilds = True
        intents.voice_states = True
        intents.message_content = True

        # setup
        self.llm = llm
        self.bot = commands.Bot(command_prefix=command_prefix, intents=intents)
        self.discord_token = os.getenv("DISCORD_TOKEN")
        self.storage_service = os.getenv("STORAGE_SERVICE")

        # Load the BotCommands cog
        self.bot.add_cog(BotCommands(self.bot))

    def run(self):
        self.bot.run(self.discord_token)


# Inline variant kept alongside the cog-based Bot above: every command is
# registered as a closure inside __init__.
class LegacyBot:
    def __init__(self, llm, agent=None, command_prefix="!"):
        load_dotenv()

        intents = discord.Intents.default()
        intents.messages = True
        intents.guilds = True
        intents.voice_states = True
        intents.message_content = True

        self.llm = llm
        self.agent = agent
        self.bot = commands.Bot(command_prefix=command_prefix, intents=intents)
        self.discord_token = os.getenv("DISCORD_TOKEN")
        self.storage_service = os.getenv("STORAGE_SERVICE")

        @self.bot.event
        async def on_ready():
            print(f"we have logged in as {self.bot.user}")

        @self.bot.command()
        async def greet(ctx):
            """greets the user."""
            await ctx.send(f"hello, {ctx.author.name}!")

        @self.bot.command()
        async def help_me(ctx):
            """provides a list of commands and their descriptions."""
            help_text = """
            - `!greet`: greets you.
            - `!run [description]`: generates a video based on the given description.
            - `!help_me`: provides this list of commands and their descriptions.
            """
            await ctx.send(help_text)

        @self.bot.event
        async def on_command_error(ctx, error):
            """handles errors that occur while executing commands."""
            if isinstance(error, commands.CommandNotFound):
                await ctx.send("that command does not exist!")
            else:
                await ctx.send(f"an error occurred: {error}")

        @self.bot.command()
        async def join(ctx):
            """joins the voice channel that the user is in."""
            if ctx.author.voice:
                channel = ctx.author.voice.channel
                await channel.connect()
            else:
                await ctx.send("you are not in a voice channel!")

        @self.bot.command()
        async def leave(ctx):
            """leaves the voice channel that the bot is in."""
            if ctx.voice_client:
                await ctx.voice_client.disconnect()
            else:
                await ctx.send("i am not in a voice channel!")

        # voice_transcription.py
        @self.bot.command()
        async def listen(ctx):
            """starts listening to voice in the voice channel that the bot is in."""
            if ctx.voice_client:
                # create a WaveSink to record the audio
                sink = discord.sinks.WaveSink()

                async def finished(sink, *args):
                    # called when recording stops
                    pass

                # start recording
                ctx.voice_client.start_recording(sink, finished)
                await ctx.send("started listening and recording.")
            else:
                await ctx.send("i am not in a voice channel!")

        # image_generator.py
        @self.bot.command()
        async def generate_image(ctx, *, prompt: str):
            """generates images based on the provided prompt"""
            await ctx.send(f"generating images for prompt: `{prompt}`...")
            loop = asyncio.get_event_loop()

            # initialize the image model
            # NOTE: `dalle3` is never imported in this module
            model_instance = dalle3()
            future = loop.run_in_executor(None, model_instance.run, prompt)

            try:
                # wait for the request to complete, with a timeout of 300 seconds
                await asyncio.wait_for(future, timeout=300)
                print("done generating images!")

                # list all files in the save directory
                all_files = [
                    os.path.join(root, file)
                    for root, _, files in os.walk(os.environ["SAVE_DIRECTORY"])
                    for file in files
                ]

                # sort files by their creation time (latest first)
                sorted_files = sorted(all_files, key=os.path.getctime, reverse=True)

                # get the 4 most recent files
                latest_files = sorted_files[:4]
                print(f"sending {len(latest_files)} images to discord...")

                # send all the latest images in a single message
                # NOTE: STORAGE_SERVICE is a URL string here
                # (e.g. "https://storage.googleapis.com/your-bucket-name/");
                # a real storage client with an upload() method is assumed
                storage_service = os.environ["STORAGE_SERVICE"]
                await ctx.send(
                    files=[
                        storage_service.upload(filepath) for filepath in latest_files
                    ]
                )

            except asyncio.TimeoutError:
                await ctx.send(
                    "the request took too long! it might have been censored or you're out of boosts. please try entering the prompt again."
                )
            except Exception as e:
                await ctx.send(f"an error occurred: {e}")

        @self.bot.command()
        async def send_text(ctx, *, text: str, use_agent: bool = True):
            """sends the provided text to the worker and returns the response"""
            if use_agent:
                response = self.agent.run(text)
            else:
                response = self.llm.run(text)
            await ctx.send(response)

    def add_command(self, name, func):
        # register under `name`; otherwise every added command would be
        # registered under the closure's own name, "command"
        @self.bot.command(name=name)
        async def command(ctx, *args):
            response = func(*args)
            await ctx.send(response)

    def run(self):
        self.bot.run(self.discord_token)
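
For reference, a minimal usage sketch of the add_command hook on the inline variant above — `echo` and `my_llm` are hypothetical stand-ins, not part of this commit:

    def echo(*args):
        # plain function exposed as a chat command
        return " ".join(args)

    bot = LegacyBot(llm=my_llm)    # any callable LLM, e.g. a BingChat instance
    bot.add_command("echo", echo)  # then, in Discord: !echo hello world
    bot.run()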
@@ -0,0 +1,6 @@
from swarms.models.bing_chat import BingChat

# Initialize the EdgeGPT model
bing = BingChat(cookies_path="./cookies.json")

task = "generate topics for PositiveMed.com: 1. Monitor Health Trends: Scan Google Alerts, authoritative health websites, and social media for emerging health, wellness, and medical discussions. 2. Keyword Research: Utilize tools like SEMrush to identify keywords with moderate to high search volume and low competition. Focus on long-tail, conversational keywords. 3. Analyze Site Data: Review PositiveMed's analytics to pinpoint popular articles and areas lacking recent content. 4. Crowdsourcing: Gather topic suggestions from the brand's audience and internal team, ensuring alignment with PositiveMed's mission. 5. Topic Evaluation: Assess topics for audience relevance, uniqueness, brand fit, current relevance, and SEO potential. 6. Tone and Style: Ensure topics can be approached with an educational, empowering, and ethical tone, in line with the brand's voice. Use this framework to generate a list of potential topics that cater to PositiveMed's audience while staying true to its brand ethos. Find trending topics for slowing and reversing aging; think step by step and go into as much detail as possible"
response = bing(task)
print(response)
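
A note on the call style here: BingChat.__call__ (defined later in this diff) also accepts a ConversationStyle, defaulting to creative. A minimal sketch of selecting a different style, using the same import path as the model file below:

    from EdgeGPT.EdgeGPT import ConversationStyle

    # ask for a more factual answer; `bing` is the instance created above
    response = bing(task, style=ConversationStyle.precise)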
@@ -0,0 +1,15 @@
import os

from dotenv import load_dotenv

from swarms.models.bing_chat import BingChat
from swarms.workers.worker import Worker
from swarms.tools.autogpt import EdgeGPTTool, tool
from swarms.models import OpenAIChat

load_dotenv("../.env")
auth_cookie = os.environ.get("AUTH_COOKIE")
auth_cookie_SRCHHPGUSR = os.environ.get("AUTH_COOKIE_SRCHHPGUSR")

# Initialize the EdgeGPT model (BingChat only takes cookies_path; the auth
# cookies above are for image generation via create_img)
bing = BingChat(cookies_path="./cookies.json")

task = "generate topics for PositiveMed.com: 1. Monitor Health Trends: Scan Google Alerts, authoritative health websites, and social media for emerging health, wellness, and medical discussions. 2. Keyword Research: Utilize tools like SEMrush to identify keywords with moderate to high search volume and low competition. Focus on long-tail, conversational keywords. 3. Analyze Site Data: Review PositiveMed's analytics to pinpoint popular articles and areas lacking recent content. 4. Crowdsourcing: Gather topic suggestions from the brand's audience and internal team, ensuring alignment with PositiveMed's mission. 5. Topic Evaluation: Assess topics for audience relevance, uniqueness, brand fit, current relevance, and SEO potential. 6. Tone and Style: Ensure topics can be approached with an educational, empowering, and ethical tone, in line with the brand's voice. Use this framework to generate a list of potential topics that cater to PositiveMed's audience while staying true to its brand ethos. Find trending topics for slowing and reversing aging; think step by step and go into as much detail as possible"

bing(task)
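
The Worker, EdgeGPTTool, and OpenAIChat imports above are never used in this script. A sketch of how they fit together, following the pattern of the (now deleted) example later in this diff — treat the names here as illustrative, not part of this commit:

    @tool
    def bing_tool(task: str = None):
        """A tool that runs inference on the BingChat model."""
        return bing(task)

    llm = OpenAIChat(openai_api_key=os.getenv("OPENAI_API_KEY"), temperature=0.5)
    worker = Worker(llm=llm, ai_name="EdgeGPT Worker", external_tools=[bing_tool])
    print(worker.run("Hello, my name is ChatGPT"))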
@@ -0,0 +1,15 @@
import os

from dotenv import load_dotenv

from apps.discord import Bot
from swarms.models.bing_chat import BingChat

load_dotenv()

# Initialize the EdgeGPT model
cookie = os.environ.get("BING_COOKIE")
auth = os.environ.get("AUTH_COOKIE")
bing = BingChat(cookies_path="./cookies.json")

bot = Bot(llm=bing)
# NOTE: generate_image and send_text are chat commands registered on the bot,
# not methods, so they cannot be called like this; create_img must also be
# passed as a callable rather than invoked eagerly:
# bot.generate_image(imggen=lambda prompt: bing.create_img(
#     prompt, auth_cookie=cookie, auth_cookie_SRCHHPGUSR=auth))
# bot.send_text(use_agent=False)
bot.run()
@@ -1,32 +0,0 @@
import os

from swarms.models.bing_chat import BingChat
from swarms.workers.worker import Worker
from swarms.tools.autogpt import EdgeGPTTool, tool
from swarms.models import OpenAIChat

api_key = os.getenv("OPENAI_API_KEY")

# Initialize the EdgeGPT model
edgegpt = BingChat(cookies_path="./cookies.txt")


# NOTE: this definition shadows the `edgegpt` model instance above
@tool
def edgegpt(task: str = None):
    """A tool to run inference on the EdgeGPT model."""
    return EdgeGPTTool.run(task)


# Initialize the language model;
# this model can be swapped out for Anthropic, Hugging Face models like Mistral, etc.
llm = OpenAIChat(
    openai_api_key=api_key,
    temperature=0.5,
)

# Initialize the Worker with the custom tool
worker = Worker(llm=llm, ai_name="EdgeGPT Worker", external_tools=[edgegpt])

# Use the worker to process a task
task = "Hello, my name is ChatGPT"
response = worker.run(task)
print(response)
@@ -0,0 +1,29 @@
import os
import sys

from dotenv import load_dotenv

from swarms.models.revgptV4 import RevChatGPTModelv4
from swarms.models.revgptV1 import RevChatGPTModelv1

root_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(root_dir)

load_dotenv()

config = {
    "model": os.getenv("REVGPT_MODEL"),
    "plugin_ids": [os.getenv("REVGPT_PLUGIN_IDS")],
    "disable_history": os.getenv("REVGPT_DISABLE_HISTORY") == "True",
    "PUID": os.getenv("REVGPT_PUID"),
    "unverified_plugin_domains": [os.getenv("REVGPT_UNVERIFIED_PLUGIN_DOMAINS")],
}

# For the v1 model
model = RevChatGPTModelv1(access_token=os.getenv("ACCESS_TOKEN"), **config)
# For the v4 model
# model = RevChatGPTModelv4(access_token=os.getenv("ACCESS_TOKEN"), **config)

# For the v3 model (class not imported here)
# model = RevChatGPTModel(access_token=os.getenv("OPENAI_API_KEY"), **config)

task = "Write a cli snake game"
response = model.run(task)
print(response)
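
For reference, the environment variables this script reads — the names come from the config dict above; the example values are placeholders, not from this commit:

    # .env (placeholder values)
    # ACCESS_TOKEN=<ChatGPT access token>
    # REVGPT_MODEL=gpt-4
    # REVGPT_PLUGIN_IDS=<plugin id>
    # REVGPT_DISABLE_HISTORY=True
    # REVGPT_PUID=<PUID cookie>
    # REVGPT_UNVERIFIED_PLUGIN_DOMAINS=<domain>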
@@ -0,0 +1,67 @@
"""Bing Chat model by Microsoft."""
import os
import asyncio
import json
from pathlib import Path

from EdgeGPT.EdgeGPT import Chatbot, ConversationStyle
from EdgeGPT.EdgeUtils import Cookie, ImageQuery, Query
from EdgeGPT.ImageGen import ImageGen


class BingChat:
    """
    EdgeGPT model by Microsoft

    Parameters
    ----------
    cookies_path : str
        Path to the cookies.json necessary for authenticating with EdgeGPT

    Examples
    --------
    >>> edgegpt = BingChat(cookies_path="./path/to/cookies.json")
    >>> response = edgegpt("Hello, my name is ChatGPT")
    >>> image_path = edgegpt.create_img("Sunset over mountains")

    """

    def __init__(self, cookies_path: str = None):
        self.cookies = json.loads(Path(cookies_path).read_text(encoding="utf-8"))
        self.bot = asyncio.run(Chatbot.create(cookies=self.cookies))

    def __call__(
        self, prompt: str, style: ConversationStyle = ConversationStyle.creative
    ) -> str:
        """
        Get a text response using the EdgeGPT model based on the provided prompt.
        """
        response = asyncio.run(
            self.bot.ask(
                prompt=prompt, conversation_style=style, simplify_response=True
            )
        )
        return response["text"]

    def create_img(
        self,
        prompt: str,
        output_dir: str = "./output",
        auth_cookie: str = None,
        auth_cookie_SRCHHPGUSR: str = None,
    ) -> str:
        """
        Generate an image based on the provided prompt and save it in the given output directory.
        Returns the path of the generated image.
        """
        if not auth_cookie:
            raise ValueError("Auth cookie is required for image generation.")

        image_generator = ImageGen(auth_cookie, auth_cookie_SRCHHPGUSR, quiet=True)
        images = image_generator.get_images(prompt)
        image_generator.save_images(images, output_dir=output_dir)

        return Path(output_dir) / images[0]

    @staticmethod
    def set_cookie_dir_path(path: str):
        """
        Set the directory path for managing cookies.
        """
        Cookie.dir_path = Path(path)
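
One caveat on the class above: both __init__ and __call__ use asyncio.run, which raises a RuntimeError when called from a context where an event loop is already running — such as the async Discord command handlers earlier in this diff. A minimal sketch of the usual workaround, offloading the blocking call to a worker thread (the function name is illustrative):

    import asyncio

    async def ask_bing(bing: "BingChat", prompt: str) -> str:
        # run the synchronous BingChat call in a thread so that the
        # asyncio.run inside __call__ gets a fresh event loop
        return await asyncio.to_thread(bing, prompt)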
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,55 @@
import openai
import ray
import uvicorn
from vllm import LLM, SamplingParams
from vllm.entrypoints import api_server as vllm_api_server
from vllm.entrypoints.openai import api_server as openai_api_server

# NOTE: the actual SkyPilot package is imported as `sky`; this import path is
# assumed by this module
from skypilot import SkyPilot


class VLLMModel:
    def __init__(self, model_name="facebook/opt-125m", tensor_parallel_size=1):
        self.model_name = model_name
        self.tensor_parallel_size = tensor_parallel_size
        self.model = LLM(model_name, tensor_parallel_size=tensor_parallel_size)
        self.temperature = 1.0
        self.max_tokens = None
        self.sampling_params = SamplingParams(temperature=self.temperature)

    def generate_text(self, prompt: str) -> str:
        output = self.model.generate([prompt], self.sampling_params)
        return output[0].outputs[0].text

    def set_temperature(self, value: float):
        # sampling params are rebuilt whenever a setting changes
        self.temperature = value
        self.sampling_params = SamplingParams(temperature=self.temperature)

    def set_max_tokens(self, value: int):
        self.max_tokens = value
        self.sampling_params = SamplingParams(
            temperature=self.temperature, max_tokens=self.max_tokens
        )

    def offline_batched_inference(self, prompts: list) -> list:
        outputs = self.model.generate(prompts, self.sampling_params)
        return [output.outputs[0].text for output in outputs]

    def start_api_server(self):
        uvicorn.run(vllm_api_server.app, host="0.0.0.0", port=8000)

    def start_openai_compatible_server(self):
        uvicorn.run(openai_api_server.app, host="0.0.0.0", port=8000)

    def query_openai_compatible_server(self, prompt: str):
        # point the OpenAI client at the local vLLM server started above
        openai.api_key = "EMPTY"
        openai.api_base = "http://localhost:8000/v1"
        completion = openai.Completion.create(model=self.model_name, prompt=prompt)
        return completion

    def distributed_inference(self, prompt: str):
        ray.init()
        # re-create the engine so tensor parallelism runs on the ray cluster
        self.model = LLM(self.model_name, tensor_parallel_size=self.tensor_parallel_size)
        output = self.model.generate(prompt, self.sampling_params)
        ray.shutdown()
        return output[0].outputs[0].text

    def run_on_cloud_with_skypilot(self, yaml_file):
        sky = SkyPilot()
        sky.launch(yaml_file)
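
A minimal usage sketch of the wrapper above (the model choice and prompts are illustrative):

    model = VLLMModel(model_name="facebook/opt-125m")
    model.set_temperature(0.7)
    model.set_max_tokens(64)
    print(model.generate_text("The capital of France is"))
    print(model.offline_batched_inference(["Hello,", "The meaning of life is"]))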