MultiModalLlava

Former-commit-id: af93db7a34
bing-chat^2
Kye 1 year ago
parent ba1b292ce7
commit 36a78e2561

@ -7,6 +7,7 @@ from invoke import Executor
from dotenv import load_dotenv
from discord.ext import commands
class Bot:
def __init__(self, agent, llm, command_prefix="!"):
load_dotenv()
@ -20,22 +21,19 @@ class Bot:
# setup
self.llm = llm
self.agent = agent
self. bot = commands.bot(command_prefix="!", intents=intents)
self.bot = commands.bot(command_prefix="!", intents=intents)
self.discord_token = os.getenv("DISCORD_TOKEN")
self.storage_service = os.getenv("STORAGE_SERVICE")
@self.bot.event
async def on_ready():
print(f"we have logged in as {self.bot.user}")
@self.bot.command()
async def greet(ctx):
"""greets the user."""
await ctx.send(f"hello, {ctx.author.name}!")
@self.bot.command()
async def help_me(ctx):
"""provides a list of commands and their descriptions."""
@ -77,7 +75,7 @@ class Bot:
"""starts listening to voice in the voice channel that the bot is in."""
if ctx.voice_client:
# create a wavesink to record the audio
sink = discord.sinks.wavesink('audio.wav')
sink = discord.sinks.wavesink("audio.wav")
# start recording
ctx.voice_client.start_recording(sink)
await ctx.send("started listening and recording.")
@ -101,7 +99,11 @@ class Bot:
print("done generating images!")
# list all files in the save_directory
all_files = [os.path.join(root, file) for root, _, files in os.walk(os.environ("SAVE_DIRECTORY")) for file in files]
all_files = [
os.path.join(root, file)
for root, _, files in os.walk(os.environ("SAVE_DIRECTORY"))
for file in files
]
# sort files by their creation time (latest first)
sorted_files = sorted(all_files, key=os.path.getctime, reverse=True)
@ -111,11 +113,19 @@ class Bot:
print(f"sending {len(latest_files)} images to discord...")
# send all the latest images in a single message
storage_service = os.environ("STORAGE_SERVICE") # "https://storage.googleapis.com/your-bucket-name/
await ctx.send(files=[storage_service.upload(filepath) for filepath in latest_files])
storage_service = os.environ(
"STORAGE_SERVICE"
) # "https://storage.googleapis.com/your-bucket-name/
await ctx.send(
files=[
storage_service.upload(filepath) for filepath in latest_files
]
)
except asyncio.timeouterror:
await ctx.send("the request took too long! it might have been censored or you're out of boosts. please try entering the prompt again.")
await ctx.send(
"the request took too long! it might have been censored or you're out of boosts. please try entering the prompt again."
)
except Exception as e:
await ctx.send(f"an error occurred: {e}")
@ -134,5 +144,6 @@ class Bot:
reponse = func(*args)
await ctx.send(responses)
def run(self) :
def run(self):
self.bot.run("DISCORD_TOKEN")

@ -11,4 +11,3 @@ task = "What were the winning boston marathon times for the past 5 years (ending
bot.send_text(task)
bot.run()

@ -15,7 +15,8 @@ from swarms.models.layoutlm_document_qa import LayoutLMDocumentQA
# from swarms.models.fuyu import Fuyu # Not working, wait until they update
import sys
log_file = open('stderr_log.txt', 'w')
log_file = open("stderr_log.txt", "w")
sys.stderr = log_file

@ -0,0 +1,79 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
class MultiModalLlava:
"""
LLava Model
Args:
model_name_or_path: The model name or path to the model
revision: The revision of the model to use
device: The device to run the model on
max_new_tokens: The maximum number of tokens to generate
do_sample: Whether or not to use sampling
temperature: The temperature of the sampling
top_p: The top p value for sampling
top_k: The top k value for sampling
repetition_penalty: The repetition penalty for sampling
device_map: The device map to use
Methods:
__call__: Call the model
chat: Interactive chat in terminal
Example:
>>> from swarms.models.llava import LlavaModel
>>> model = LlavaModel(device="cpu")
>>> model("Hello, I am a robot.")
"""
def __init__(
self,
model_name_or_path="TheBloke/llava-v1.5-13B-GPTQ",
revision="main",
device="cuda",
max_new_tokens=512,
do_sample=True,
temperature=0.7,
top_p=0.95,
top_k=40,
repetition_penalty=1.1,
device_map: str = "auto"
):
self.device = device
self.model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
device_map=device_map,
trust_remote_code=False,
revision=revision,
).to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path, use_fast=True
)
self.pipe = pipeline(
"text-generation",
model=self.model,
tokenizer=self.tokenizer,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
device=0 if self.device == "cuda" else -1,
)
def __call__(self, prompt):
"""Call the model"""
return self.pipe(prompt)[0]["generated_text"]
def chat(self):
"""Interactive chat in terminal"""
print("Starting chat with LlavaModel. Type 'exit' to end the session.")
while True:
user_input = input("You: ")
if user_input.lower() == "exit":
break
response = self(user_input)
print(f"Model: {response}")

@ -1,22 +1,24 @@
import unittest
from unittest.mock import patch, Mock, MagicMock
from apps.discord import Bot # Replace 'Bot' with the name of the file containing your bot's code.
from apps.discord import (
Bot,
) # Replace 'Bot' with the name of the file containing your bot's code.
class TestBot(unittest.TestCase):
class TestBot(unittest.TestCase):
def setUp(self):
self.llm_mock = Mock()
self.agent_mock = Mock()
self.bot = Bot(agent=self.agent_mock, llm=self.llm_mock)
@patch('Bot.load_dotenv') # Mocking the `load_dotenv` function call.
@patch("Bot.load_dotenv") # Mocking the `load_dotenv` function call.
def test_initialization(self, mock_load_dotenv):
self.assertIsNotNone(self.bot.bot)
self.assertEqual(self.bot.agent, self.agent_mock)
self.assertEqual(self.bot.llm, self.llm_mock)
mock_load_dotenv.assert_called_once()
@patch('Bot.commands.bot')
@patch("Bot.commands.bot")
def test_greet(self, mock_bot):
ctx_mock = Mock()
ctx_mock.author.name = "TestUser"
@ -26,7 +28,7 @@ class TestBot(unittest.TestCase):
# Similarly, you can add tests for other commands.
@patch('Bot.commands.bot')
@patch("Bot.commands.bot")
def test_help_me(self, mock_bot):
ctx_mock = Mock()
self.bot.bot.clear()
@ -34,7 +36,7 @@ class TestBot(unittest.TestCase):
# Verify the help text was sent. You can check for a substring to make it shorter.
ctx_mock.send.assert_called()
@patch('Bot.commands.bot')
@patch("Bot.commands.bot")
def test_on_command_error(self, mock_bot):
ctx_mock = Mock()
error_mock = Mock()
@ -52,5 +54,6 @@ class TestBot(unittest.TestCase):
# You can add more tests for other commands and functionalities.
if __name__ == "__main__":
unittest.main()

Loading…
Cancel
Save