MultiModalLlava

pull/68/head^2
Kye 1 year ago
parent fe54e84741
commit af93db7a34

@ -7,10 +7,11 @@ from invoke import Executor
from dotenv import load_dotenv from dotenv import load_dotenv
from discord.ext import commands from discord.ext import commands
class Bot: class Bot:
def __init__(self, agent, llm, command_prefix="!"): def __init__(self, agent, llm, command_prefix="!"):
load_dotenv() load_dotenv()
intents = discord.intents.default() intents = discord.intents.default()
intents.messages = True intents.messages = True
intents.guilds = True intents.guilds = True
@ -20,22 +21,19 @@ class Bot:
# setup # setup
self.llm = llm self.llm = llm
self.agent = agent self.agent = agent
self. bot = commands.bot(command_prefix="!", intents=intents) self.bot = commands.bot(command_prefix="!", intents=intents)
self.discord_token = os.getenv("DISCORD_TOKEN") self.discord_token = os.getenv("DISCORD_TOKEN")
self.storage_service = os.getenv("STORAGE_SERVICE") self.storage_service = os.getenv("STORAGE_SERVICE")
@self.bot.event @self.bot.event
async def on_ready(): async def on_ready():
print(f"we have logged in as {self.bot.user}") print(f"we have logged in as {self.bot.user}")
@self.bot.command() @self.bot.command()
async def greet(ctx): async def greet(ctx):
"""greets the user.""" """greets the user."""
await ctx.send(f"hello, {ctx.author.name}!") await ctx.send(f"hello, {ctx.author.name}!")
@self.bot.command() @self.bot.command()
async def help_me(ctx): async def help_me(ctx):
"""provides a list of commands and their descriptions.""" """provides a list of commands and their descriptions."""
@ -77,13 +75,13 @@ class Bot:
"""starts listening to voice in the voice channel that the bot is in.""" """starts listening to voice in the voice channel that the bot is in."""
if ctx.voice_client: if ctx.voice_client:
# create a wavesink to record the audio # create a wavesink to record the audio
sink = discord.sinks.wavesink('audio.wav') sink = discord.sinks.wavesink("audio.wav")
# start recording # start recording
ctx.voice_client.start_recording(sink) ctx.voice_client.start_recording(sink)
await ctx.send("started listening and recording.") await ctx.send("started listening and recording.")
else: else:
await ctx.send("i am not in a voice channel!") await ctx.send("i am not in a voice channel!")
# image_generator.py # image_generator.py
@self.bot.command() @self.bot.command()
async def generate_image(ctx, *, prompt: str): async def generate_image(ctx, *, prompt: str):
@ -101,7 +99,11 @@ class Bot:
print("done generating images!") print("done generating images!")
# list all files in the save_directory # list all files in the save_directory
all_files = [os.path.join(root, file) for root, _, files in os.walk(os.environ("SAVE_DIRECTORY")) for file in files] all_files = [
os.path.join(root, file)
for root, _, files in os.walk(os.environ("SAVE_DIRECTORY"))
for file in files
]
# sort files by their creation time (latest first) # sort files by their creation time (latest first)
sorted_files = sorted(all_files, key=os.path.getctime, reverse=True) sorted_files = sorted(all_files, key=os.path.getctime, reverse=True)
@ -111,11 +113,19 @@ class Bot:
print(f"sending {len(latest_files)} images to discord...") print(f"sending {len(latest_files)} images to discord...")
# send all the latest images in a single message # send all the latest images in a single message
storage_service = os.environ("STORAGE_SERVICE") # "https://storage.googleapis.com/your-bucket-name/ storage_service = os.environ(
await ctx.send(files=[storage_service.upload(filepath) for filepath in latest_files]) "STORAGE_SERVICE"
) # "https://storage.googleapis.com/your-bucket-name/
await ctx.send(
files=[
storage_service.upload(filepath) for filepath in latest_files
]
)
except asyncio.timeouterror: except asyncio.timeouterror:
await ctx.send("the request took too long! it might have been censored or you're out of boosts. please try entering the prompt again.") await ctx.send(
"the request took too long! it might have been censored or you're out of boosts. please try entering the prompt again."
)
except Exception as e: except Exception as e:
await ctx.send(f"an error occurred: {e}") await ctx.send(f"an error occurred: {e}")
@ -127,12 +137,13 @@ class Bot:
else: else:
response = self.llm.run(text) response = self.llm.run(text)
await ctx.send(response) await ctx.send(response)
def add_command(self, name, func): def add_command(self, name, func):
@self.bot.command() @self.bot.command()
async def command(ctx, *args): async def command(ctx, *args):
reponse = func(*args) reponse = func(*args)
await ctx.send(responses) await ctx.send(responses)
def run(self) :
def run(self):
self.bot.run("DISCORD_TOKEN") self.bot.run("DISCORD_TOKEN")

@ -11,4 +11,3 @@ task = "What were the winning boston marathon times for the past 5 years (ending
bot.send_text(task) bot.send_text(task)
bot.run() bot.run()

@ -15,7 +15,8 @@ from swarms.models.layoutlm_document_qa import LayoutLMDocumentQA
# from swarms.models.fuyu import Fuyu # Not working, wait until they update # from swarms.models.fuyu import Fuyu # Not working, wait until they update
import sys import sys
log_file = open('stderr_log.txt', 'w')
log_file = open("stderr_log.txt", "w")
sys.stderr = log_file sys.stderr = log_file

@ -0,0 +1,79 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
class MultiModalLlava:
    """
    LLava text-generation model wrapper.

    Loads a (GPTQ-quantized) LLaVA checkpoint and exposes it through a
    Hugging Face ``text-generation`` pipeline with a simple callable
    interface.

    Args:
        model_name_or_path: The model name or path to the model
        revision: The revision of the model to use
        device: The device the caller intends to run on ("cuda" or "cpu");
            actual placement is delegated to ``device_map`` (see below)
        max_new_tokens: The maximum number of tokens to generate
        do_sample: Whether or not to use sampling
        temperature: The temperature of the sampling
        top_p: The top p value for sampling
        top_k: The top k value for sampling
        repetition_penalty: The repetition penalty for sampling
        device_map: The device map passed to ``from_pretrained`` (accelerate
            dispatches the model across devices according to this)

    Methods:
        __call__: Call the model on a prompt and return the generated text
        chat: Interactive chat in terminal

    Example:
        >>> from swarms.models.llava import MultiModalLlava
        >>> model = MultiModalLlava(device="cpu")
        >>> model("Hello, I am a robot.")
    """

    def __init__(
        self,
        model_name_or_path="TheBloke/llava-v1.5-13B-GPTQ",
        revision="main",
        device="cuda",
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        top_k=40,
        repetition_penalty=1.1,
        device_map: str = "auto",
    ):
        self.device = device
        # The model is placed by accelerate via `device_map`; calling
        # `.to(device)` on a device_map-dispatched model raises a
        # RuntimeError in transformers, so the model must NOT be moved again.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            device_map=device_map,
            trust_remote_code=False,
            revision=revision,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, use_fast=True
        )
        # Do not pass an explicit `device=` here: transformers rejects a
        # pipeline `device` argument when the model already carries an
        # `hf_device_map` from `device_map` loading — the two placement
        # mechanisms conflict.
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
        )

    def __call__(self, prompt):
        """Generate text for *prompt* and return the generated string."""
        return self.pipe(prompt)[0]["generated_text"]

    def chat(self):
        """Interactive chat in terminal; type 'exit' to end the session."""
        print("Starting chat with LlavaModel. Type 'exit' to end the session.")
        while True:
            user_input = input("You: ")
            if user_input.lower() == "exit":
                break
            response = self(user_input)
            print(f"Model: {response}")

@ -1,22 +1,24 @@
import unittest import unittest
from unittest.mock import patch, Mock, MagicMock from unittest.mock import patch, Mock, MagicMock
from apps.discord import Bot # Replace 'Bot' with the name of the file containing your bot's code. from apps.discord import (
Bot,
) # Replace 'Bot' with the name of the file containing your bot's code.
class TestBot(unittest.TestCase):
class TestBot(unittest.TestCase):
def setUp(self): def setUp(self):
self.llm_mock = Mock() self.llm_mock = Mock()
self.agent_mock = Mock() self.agent_mock = Mock()
self.bot = Bot(agent=self.agent_mock, llm=self.llm_mock) self.bot = Bot(agent=self.agent_mock, llm=self.llm_mock)
@patch('Bot.load_dotenv') # Mocking the `load_dotenv` function call. @patch("Bot.load_dotenv") # Mocking the `load_dotenv` function call.
def test_initialization(self, mock_load_dotenv): def test_initialization(self, mock_load_dotenv):
self.assertIsNotNone(self.bot.bot) self.assertIsNotNone(self.bot.bot)
self.assertEqual(self.bot.agent, self.agent_mock) self.assertEqual(self.bot.agent, self.agent_mock)
self.assertEqual(self.bot.llm, self.llm_mock) self.assertEqual(self.bot.llm, self.llm_mock)
mock_load_dotenv.assert_called_once() mock_load_dotenv.assert_called_once()
@patch('Bot.commands.bot') @patch("Bot.commands.bot")
def test_greet(self, mock_bot): def test_greet(self, mock_bot):
ctx_mock = Mock() ctx_mock = Mock()
ctx_mock.author.name = "TestUser" ctx_mock.author.name = "TestUser"
@ -26,7 +28,7 @@ class TestBot(unittest.TestCase):
# Similarly, you can add tests for other commands. # Similarly, you can add tests for other commands.
@patch('Bot.commands.bot') @patch("Bot.commands.bot")
def test_help_me(self, mock_bot): def test_help_me(self, mock_bot):
ctx_mock = Mock() ctx_mock = Mock()
self.bot.bot.clear() self.bot.bot.clear()
@ -34,7 +36,7 @@ class TestBot(unittest.TestCase):
# Verify the help text was sent. You can check for a substring to make it shorter. # Verify the help text was sent. You can check for a substring to make it shorter.
ctx_mock.send.assert_called() ctx_mock.send.assert_called()
@patch('Bot.commands.bot') @patch("Bot.commands.bot")
def test_on_command_error(self, mock_bot): def test_on_command_error(self, mock_bot):
ctx_mock = Mock() ctx_mock = Mock()
error_mock = Mock() error_mock = Mock()
@ -52,5 +54,6 @@ class TestBot(unittest.TestCase):
# You can add more tests for other commands and functionalities. # You can add more tests for other commands and functionalities.
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

Loading…
Cancel
Save