commit e8efb3cbde
parent 39a8d7672f
branch: group-chat
author: Kye, 1 year ago
Former-commit-id: cc041dcf77

@@ -1,7 +1,7 @@
 from swarms.agents import MultiModalAgent
 load_dict = {
-    "ImageCaptioning": "cuda"
+    "ImageCaptioning": "cpu"
 }
 node = MultiModalAgent(load_dict)
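Note: load_dict appears to map each model worker to the device it should load on; this commit moves ImageCaptioning from "cuda" to "cpu". A minimal sketch of the resulting snippet, with one extra worker and a run call shown only as comments because neither appears in this diff (both are hypothetical):

    from swarms.agents import MultiModalAgent

    # Keys name model workers, values select the device they load on (per the hunk above).
    load_dict = {
        "ImageCaptioning": "cpu",
        # "VisualQuestionAnswering": "cpu",  # hypothetical second worker
    }

    node = MultiModalAgent(load_dict)
    # node.run("Describe this image", "photo.jpg")  # hypothetical call; the exact method name is not shown here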

@@ -28,6 +28,32 @@ colored
+addict
+albumentations
+basicsr
+controlnet-aux
+diffusers
+einops
+imageio
+imageio-ffmpeg
+invisible-watermark
+kornia
+numpy
+omegaconf
+open_clip_torch
+openai
+opencv-python
+prettytable
+safetensors
+streamlit
+test-tube
+timm
+torchmetrics
+transformers
+webdataset
+yapf
 mkdocs
 mkdocs-material
 mkdocs-glightbox

@@ -1,5 +1,4 @@
 import os
-import gradio as gr
 import random
 import torch
 import cv2

@@ -17,12 +17,11 @@ from langchain.chains.qa_with_sources.loading import BaseCombineDocumentsChain
 from langchain.chat_models import ChatOpenAI
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.tools import BaseTool
-from langchain.tools.file_management.read import ReadFileTool
-from langchain.tools.file_management.write import WriteFileTool
 from pydantic import Field
 from swarms.utils.logger import logger
+from langchain.tools.file_management.write import WriteFileTool
+from langchain.tools.file_management.read import ReadFileTool

 @contextmanager
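The two file-management imports appear to be moved below the logger import rather than dropped. For reference, a minimal sketch of how LangChain's WriteFileTool and ReadFileTool are typically called (the file name is illustrative only):

    from langchain.tools.file_management.read import ReadFileTool
    from langchain.tools.file_management.write import WriteFileTool

    write_tool = WriteFileTool()
    read_tool = ReadFileTool()

    # Each tool takes a dict of arguments; "notes.txt" is just an example path.
    write_tool.run({"file_path": "notes.txt", "text": "hello from the worker"})
    print(read_tool.run({"file_path": "notes.txt"}))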
@@ -141,6 +140,7 @@ query_website_tool = WebpageQATool(qa_chain=load_qa_with_sources_chain(llm))
 # code_intepret = CodeInterpreter()
+import interpreter
 @tool
 def compile(task: str):
     """
@@ -169,41 +169,42 @@ def compile(task: str):
 # mm model workers
-import torch
-from PIL import Image
-from transformers import (
-    BlipForQuestionAnswering,
-    BlipProcessor,
-)
-
-@tool
-def VQAinference(self, inputs):
-    """
-    Answer Question About The Image, VQA Multi-Modal Worker agent
-    description="useful when you need an answer for a question based on an image. "
-    "like: what is the background color of the last image, how many cats in this figure, what is in this figure. "
-    "The input to this tool should be a comma separated string of two, representing the image_path and the question",
-
-    """
-    device = "cuda:0"
-    torch_dtype = torch.float16 if "cuda" in device else torch.float32
-    processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
-    model = BlipForQuestionAnswering.from_pretrained(
-        "Salesforce/blip-vqa-base", torch_dtype=torch_dtype
-    ).to(device)
-
-    image_path, question = inputs.split(",")
-    raw_image = Image.open(image_path).convert("RGB")
-    inputs = processor(raw_image, question, return_tensors="pt").to(
-        device, torch_dtype
-    )
-    out = model.generate(**inputs)
-    answer = processor.decode(out[0], skip_special_tokens=True)
-
-    logger.debug(
-        f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input Question: {question}, "
-        f"Output Answer: {answer}"
-    )
-
-    return answer
+# import torch
+# from PIL import Image
+# from transformers import (
+#     BlipForQuestionAnswering,
+#     BlipProcessor,
+# )
+
+# @tool
+# def VQAinference(self, inputs):
+#     """
+#     Answer Question About The Image, VQA Multi-Modal Worker agent
+#     description="useful when you need an answer for a question based on an image. "
+#     "like: what is the background color of the last image, how many cats in this figure, what is in this figure. "
+#     "The input to this tool should be a comma separated string of two, representing the image_path and the question",
+
+#     """
+#     device = "cuda:0"
+#     torch_dtype = torch.float16 if "cuda" in device else torch.float32
+#     processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
+#     model = BlipForQuestionAnswering.from_pretrained(
+#         "Salesforce/blip-vqa-base", torch_dtype=torch_dtype
+#     ).to(device)
+
+#     image_path, question = inputs.split(",")
+#     raw_image = Image.open(image_path).convert("RGB")
+#     inputs = processor(raw_image, question, return_tensors="pt").to(
+#         device, torch_dtype
+#     )
+#     out = model.generate(**inputs)
+#     answer = processor.decode(out[0], skip_special_tokens=True)
+
+#     logger.debug(
+#         f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input Question: {question}, "
+#         f"Output Answer: {answer}"
+#     )
+
+#     return answer

@ -174,7 +174,7 @@ class Worker:
query_website_tool, query_website_tool,
HumanInputRun(), HumanInputRun(),
compile, compile,
VQAinference # VQAinference
] ]
if external_tools is not None: if external_tools is not None:
self.tools.extend(external_tools) self.tools.extend(external_tools)
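The final hunk implies that Worker accepts an external_tools list which is appended to its built-in tools (VQAinference is now excluded from the defaults). A hypothetical sketch of supplying an extra tool; the Worker constructor's other arguments are not shown in this diff:

    from langchain.tools import tool


    @tool
    def word_count(text: str) -> str:
        """Count the words in a piece of text."""
        return str(len(text.split()))

    # Assumption: Worker exposes an external_tools parameter, as
    # `self.tools.extend(external_tools)` in the hunk suggests.
    # worker = Worker(external_tools=[word_count])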
