From 36a78e2561bf6b391135de6eb8e76b266855e315 Mon Sep 17 00:00:00 2001 From: Kye Date: Mon, 23 Oct 2023 16:47:50 -0400 Subject: [PATCH] MultiModalLlava Former-commit-id: af93db7a3446632a9f822d2b2bdcaad258b9e5ba --- apps/discord.py | 39 +++++++++------ playground/apps/discord_example.py | 1 - swarms/models/__init__.py | 3 +- swarms/models/llava.py | 79 ++++++++++++++++++++++++++++++ tests/apps/discord.py | 15 +++--- 5 files changed, 115 insertions(+), 22 deletions(-) create mode 100644 swarms/models/llava.py diff --git a/apps/discord.py b/apps/discord.py index a03d0835..f605a108 100644 --- a/apps/discord.py +++ b/apps/discord.py @@ -7,10 +7,11 @@ from invoke import Executor from dotenv import load_dotenv from discord.ext import commands + class Bot: def __init__(self, agent, llm, command_prefix="!"): load_dotenv() - + intents = discord.intents.default() intents.messages = True intents.guilds = True @@ -20,22 +21,19 @@ class Bot: # setup self.llm = llm self.agent = agent - self. bot = commands.bot(command_prefix="!", intents=intents) + self.bot = commands.bot(command_prefix="!", intents=intents) self.discord_token = os.getenv("DISCORD_TOKEN") self.storage_service = os.getenv("STORAGE_SERVICE") - @self.bot.event async def on_ready(): print(f"we have logged in as {self.bot.user}") - @self.bot.command() async def greet(ctx): """greets the user.""" await ctx.send(f"hello, {ctx.author.name}!") - @self.bot.command() async def help_me(ctx): """provides a list of commands and their descriptions.""" @@ -77,13 +75,13 @@ class Bot: """starts listening to voice in the voice channel that the bot is in.""" if ctx.voice_client: # create a wavesink to record the audio - sink = discord.sinks.wavesink('audio.wav') + sink = discord.sinks.wavesink("audio.wav") # start recording ctx.voice_client.start_recording(sink) await ctx.send("started listening and recording.") else: await ctx.send("i am not in a voice channel!") - + # image_generator.py @self.bot.command() async def generate_image(ctx, *, prompt: str): @@ -101,7 +99,11 @@ class Bot: print("done generating images!") # list all files in the save_directory - all_files = [os.path.join(root, file) for root, _, files in os.walk(os.environ("SAVE_DIRECTORY")) for file in files] + all_files = [ + os.path.join(root, file) + for root, _, files in os.walk(os.environ("SAVE_DIRECTORY")) + for file in files + ] # sort files by their creation time (latest first) sorted_files = sorted(all_files, key=os.path.getctime, reverse=True) @@ -111,11 +113,19 @@ class Bot: print(f"sending {len(latest_files)} images to discord...") # send all the latest images in a single message - storage_service = os.environ("STORAGE_SERVICE") # "https://storage.googleapis.com/your-bucket-name/ - await ctx.send(files=[storage_service.upload(filepath) for filepath in latest_files]) + storage_service = os.environ( + "STORAGE_SERVICE" + ) # "https://storage.googleapis.com/your-bucket-name/ + await ctx.send( + files=[ + storage_service.upload(filepath) for filepath in latest_files + ] + ) except asyncio.timeouterror: - await ctx.send("the request took too long! it might have been censored or you're out of boosts. please try entering the prompt again.") + await ctx.send( + "the request took too long! it might have been censored or you're out of boosts. please try entering the prompt again." + ) except Exception as e: await ctx.send(f"an error occurred: {e}") @@ -127,12 +137,13 @@ class Bot: else: response = self.llm.run(text) await ctx.send(response) - + def add_command(self, name, func): @self.bot.command() async def command(ctx, *args): reponse = func(*args) await ctx.send(responses) - -def run(self) : + + +def run(self): self.bot.run("DISCORD_TOKEN") diff --git a/playground/apps/discord_example.py b/playground/apps/discord_example.py index 2010f71e..a3a90cf6 100644 --- a/playground/apps/discord_example.py +++ b/playground/apps/discord_example.py @@ -11,4 +11,3 @@ task = "What were the winning boston marathon times for the past 5 years (ending bot.send_text(task) bot.run() - diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py index 85029c26..c12d9dda 100644 --- a/swarms/models/__init__.py +++ b/swarms/models/__init__.py @@ -15,7 +15,8 @@ from swarms.models.layoutlm_document_qa import LayoutLMDocumentQA # from swarms.models.fuyu import Fuyu # Not working, wait until they update import sys -log_file = open('stderr_log.txt', 'w') + +log_file = open("stderr_log.txt", "w") sys.stderr = log_file diff --git a/swarms/models/llava.py b/swarms/models/llava.py new file mode 100644 index 00000000..67c0e4a7 --- /dev/null +++ b/swarms/models/llava.py @@ -0,0 +1,79 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline + + +class MultiModalLlava: + """ + LLava Model + + Args: + model_name_or_path: The model name or path to the model + revision: The revision of the model to use + device: The device to run the model on + max_new_tokens: The maximum number of tokens to generate + do_sample: Whether or not to use sampling + temperature: The temperature of the sampling + top_p: The top p value for sampling + top_k: The top k value for sampling + repetition_penalty: The repetition penalty for sampling + device_map: The device map to use + + Methods: + __call__: Call the model + chat: Interactive chat in terminal + + Example: + >>> from swarms.models.llava import LlavaModel + >>> model = LlavaModel(device="cpu") + >>> model("Hello, I am a robot.") + """ + def __init__( + self, + model_name_or_path="TheBloke/llava-v1.5-13B-GPTQ", + revision="main", + device="cuda", + max_new_tokens=512, + do_sample=True, + temperature=0.7, + top_p=0.95, + top_k=40, + repetition_penalty=1.1, + device_map: str = "auto" + ): + self.device = device + self.model = AutoModelForCausalLM.from_pretrained( + model_name_or_path, + device_map=device_map, + trust_remote_code=False, + revision=revision, + ).to(self.device) + + self.tokenizer = AutoTokenizer.from_pretrained( + model_name_or_path, use_fast=True + ) + self.pipe = pipeline( + "text-generation", + model=self.model, + tokenizer=self.tokenizer, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + top_k=top_k, + repetition_penalty=repetition_penalty, + device=0 if self.device == "cuda" else -1, + ) + + def __call__(self, prompt): + """Call the model""" + return self.pipe(prompt)[0]["generated_text"] + + def chat(self): + """Interactive chat in terminal""" + print("Starting chat with LlavaModel. Type 'exit' to end the session.") + while True: + user_input = input("You: ") + if user_input.lower() == "exit": + break + response = self(user_input) + print(f"Model: {response}") + diff --git a/tests/apps/discord.py b/tests/apps/discord.py index 2e07e2b3..bc8daa80 100644 --- a/tests/apps/discord.py +++ b/tests/apps/discord.py @@ -1,22 +1,24 @@ import unittest from unittest.mock import patch, Mock, MagicMock -from apps.discord import Bot # Replace 'Bot' with the name of the file containing your bot's code. +from apps.discord import ( + Bot, +) # Replace 'Bot' with the name of the file containing your bot's code. -class TestBot(unittest.TestCase): +class TestBot(unittest.TestCase): def setUp(self): self.llm_mock = Mock() self.agent_mock = Mock() self.bot = Bot(agent=self.agent_mock, llm=self.llm_mock) - @patch('Bot.load_dotenv') # Mocking the `load_dotenv` function call. + @patch("Bot.load_dotenv") # Mocking the `load_dotenv` function call. def test_initialization(self, mock_load_dotenv): self.assertIsNotNone(self.bot.bot) self.assertEqual(self.bot.agent, self.agent_mock) self.assertEqual(self.bot.llm, self.llm_mock) mock_load_dotenv.assert_called_once() - @patch('Bot.commands.bot') + @patch("Bot.commands.bot") def test_greet(self, mock_bot): ctx_mock = Mock() ctx_mock.author.name = "TestUser" @@ -26,7 +28,7 @@ class TestBot(unittest.TestCase): # Similarly, you can add tests for other commands. - @patch('Bot.commands.bot') + @patch("Bot.commands.bot") def test_help_me(self, mock_bot): ctx_mock = Mock() self.bot.bot.clear() @@ -34,7 +36,7 @@ class TestBot(unittest.TestCase): # Verify the help text was sent. You can check for a substring to make it shorter. ctx_mock.send.assert_called() - @patch('Bot.commands.bot') + @patch("Bot.commands.bot") def test_on_command_error(self, mock_bot): ctx_mock = Mock() error_mock = Mock() @@ -52,5 +54,6 @@ class TestBot(unittest.TestCase): # You can add more tests for other commands and functionalities. + if __name__ == "__main__": unittest.main()