@ -0,0 +1,3 @@
|
|||||||
|
.env
|
||||||
|
__pycache__
|
||||||
|
.venv
|
@ -0,0 +1,27 @@
|
|||||||
|
OPENAI_API_KEY=""
|
||||||
|
WOLFRAM_ALPHA_APPID=""
|
||||||
|
ZAPIER_NLA_API_KEY=""
|
||||||
|
|
||||||
|
EVAL_PORT=8000
|
||||||
|
MODEL_NAME=""
|
||||||
|
CELERY_BROKER_URL=""
|
||||||
|
|
||||||
|
SERVER=""
|
||||||
|
USE_GPU=True
|
||||||
|
PLAYGROUND_DIR="playground"
|
||||||
|
|
||||||
|
OPENAI_API_KEY="your_openai_api_key_here"
|
||||||
|
LOG_LEVEL="INFO"
|
||||||
|
BOT_NAME="Orca"
|
||||||
|
|
||||||
|
WINEDB_HOST="your_winedb_host_here"
|
||||||
|
WINEDB_PASSWORD="your_winedb_password_here"
|
||||||
|
BING_SEARCH_URL="your_bing_search_url_here"
|
||||||
|
|
||||||
|
BING_SUBSCRIPTION_KEY="your_bing_subscription_key_here"
|
||||||
|
SERPAPI_API_KEY="your_serpapi_api_key_here"
|
||||||
|
IFTTTKey=""
|
||||||
|
|
||||||
|
BRAVE_API_KEY=""
|
||||||
|
SPOONACULAR_KEY=""
|
||||||
|
HF_API_KEY="Huggingface api key"
|
@ -0,0 +1,49 @@
|
|||||||
|
name: release
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
types:
|
||||||
|
- closed
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
paths:
|
||||||
|
- 'pyproject.toml'
|
||||||
|
|
||||||
|
env:
|
||||||
|
POETRY_VERSION: "1.4.2"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
if_release:
|
||||||
|
if: |
|
||||||
|
${{ github.event.pull_request.merged == true }}
|
||||||
|
&& ${{ contains(github.event.pull_request.labels.*.name, 'release') }}
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
- name: Install poetry
|
||||||
|
run: pipx install poetry==$POETRY_VERSION
|
||||||
|
- name: Set up Python 3.10
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: "3.10"
|
||||||
|
cache: "poetry"
|
||||||
|
- name: Build project for distribution
|
||||||
|
run: poetry build
|
||||||
|
- name: Check Version
|
||||||
|
id: check-version
|
||||||
|
run: |
|
||||||
|
echo version=$(poetry version --short) >> $GITHUB_OUTPUT
|
||||||
|
- name: Create Release
|
||||||
|
uses: ncipollo/release-action@v1
|
||||||
|
with:
|
||||||
|
artifacts: "dist/*"
|
||||||
|
token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
draft: false
|
||||||
|
generateReleaseNotes: true
|
||||||
|
tag: v${{ steps.check-version.outputs.version }}
|
||||||
|
commit: master
|
||||||
|
- name: Publish to PyPI
|
||||||
|
env:
|
||||||
|
POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYPI_API_TOKEN }}
|
||||||
|
run: |
|
||||||
|
poetry publish
|
@ -0,0 +1,32 @@
|
|||||||
|
|
||||||
|
name: Upload Python Package
|
||||||
|
|
||||||
|
on:
|
||||||
|
release:
|
||||||
|
types: [published]
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
deploy:
|
||||||
|
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v3
|
||||||
|
with:
|
||||||
|
python-version: '3.x'
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install build
|
||||||
|
- name: Build package
|
||||||
|
run: python -m build
|
||||||
|
- name: Publish package
|
||||||
|
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
|
||||||
|
with:
|
||||||
|
user: __token__
|
||||||
|
password: ${{ secrets.PYPI_API_TOKEN }}
|
@ -0,0 +1,49 @@
|
|||||||
|
name: test
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [master]
|
||||||
|
pull_request:
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
env:
|
||||||
|
POETRY_VERSION: "1.4.2"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version:
|
||||||
|
- "3.8"
|
||||||
|
- "3.9"
|
||||||
|
- "3.10"
|
||||||
|
- "3.11"
|
||||||
|
test_type:
|
||||||
|
- "core"
|
||||||
|
- "extended"
|
||||||
|
name: Python ${{ matrix.python-version }} ${{ matrix.test_type }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: "./.github/actions/poetry_setup"
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
poetry-version: "1.4.2"
|
||||||
|
cache-key: ${{ matrix.test_type }}
|
||||||
|
install-command: |
|
||||||
|
if [ "${{ matrix.test_type }}" == "core" ]; then
|
||||||
|
echo "Running core tests, installing dependencies with poetry..."
|
||||||
|
poetry install
|
||||||
|
else
|
||||||
|
echo "Running extended tests, installing dependencies with poetry..."
|
||||||
|
poetry install -E extended_testing
|
||||||
|
fi
|
||||||
|
- name: Run ${{matrix.test_type}} tests
|
||||||
|
run: |
|
||||||
|
if [ "${{ matrix.test_type }}" == "core" ]; then
|
||||||
|
make test
|
||||||
|
else
|
||||||
|
make extended_tests
|
||||||
|
fi
|
||||||
|
shell: bash
|
@ -0,0 +1,45 @@
|
|||||||
|
name: build
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
|
||||||
|
build:
|
||||||
|
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Setup Python
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: '3.10'
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: pip install -r requirements.txt
|
||||||
|
|
||||||
|
- name: Run Python unit tests
|
||||||
|
run: python3 -m unittest tests/swarms
|
||||||
|
|
||||||
|
- name: Verify that the Docker image for the action builds
|
||||||
|
run: docker build . --file Dockerfile
|
||||||
|
|
||||||
|
- name: Integration test 1
|
||||||
|
uses: ./
|
||||||
|
with:
|
||||||
|
input-one: something
|
||||||
|
input-two: true
|
||||||
|
|
||||||
|
- name: Integration test 2
|
||||||
|
uses: ./
|
||||||
|
with:
|
||||||
|
input-one: something else
|
||||||
|
input-two: false
|
||||||
|
|
||||||
|
- name: Verify integration test results
|
||||||
|
run: python3 -m unittest unittesting/swarms
|
@ -0,0 +1,12 @@
|
|||||||
|
__pycache__/
|
||||||
|
.venv/
|
||||||
|
|
||||||
|
.env
|
||||||
|
|
||||||
|
image/
|
||||||
|
audio/
|
||||||
|
video/
|
||||||
|
dataframe/
|
||||||
|
|
||||||
|
static/generated
|
||||||
|
swarms/__pycache__
|
@ -0,0 +1,248 @@
|
|||||||
|
# Contributing to Swarms
|
||||||
|
|
||||||
|
Hi there! Thank you for even being interested in contributing to Swarms.
|
||||||
|
As an open source project in a rapidly developing field, we are extremely open
|
||||||
|
to contributions, whether they be in the form of new features, improved infra, better documentation, or bug fixes.
|
||||||
|
|
||||||
|
## 🗺️ Guidelines
|
||||||
|
|
||||||
|
### 👩💻 Contributing Code
|
||||||
|
|
||||||
|
To contribute to this project, please follow a ["fork and pull request"](https://docs.github.com/en/get-started/quickstart/contributing-to-projects) workflow.
|
||||||
|
Please do not try to push directly to this repo unless you are maintainer.
|
||||||
|
|
||||||
|
Please follow the checked-in pull request template when opening pull requests. Note related issues and tag relevant
|
||||||
|
maintainers.
|
||||||
|
|
||||||
|
Pull requests cannot land without passing the formatting, linting and testing checks first. See
|
||||||
|
[Common Tasks](#-common-tasks) for how to run these checks locally.
|
||||||
|
|
||||||
|
It's essential that we maintain great documentation and testing. If you:
|
||||||
|
- Fix a bug
|
||||||
|
- Add a relevant unit or integration test when possible. These live in `tests/unit_tests` and `tests/integration_tests`.
|
||||||
|
- Make an improvement
|
||||||
|
- Update any affected example notebooks and documentation. These lives in `docs`.
|
||||||
|
- Update unit and integration tests when relevant.
|
||||||
|
- Add a feature
|
||||||
|
- Add a demo notebook in `docs/modules`.
|
||||||
|
- Add unit and integration tests.
|
||||||
|
|
||||||
|
We're a small, building-oriented team. If there's something you'd like to add or change, opening a pull request is the
|
||||||
|
best way to get our attention.
|
||||||
|
|
||||||
|
### 🚩GitHub Issues
|
||||||
|
|
||||||
|
Our [issues](https://github.com/kyegomez/Swarms/issues) page is kept up to date
|
||||||
|
with bugs, improvements, and feature requests.
|
||||||
|
|
||||||
|
There is a taxonomy of labels to help with sorting and discovery of issues of interest. Please use these to help
|
||||||
|
organize issues.
|
||||||
|
|
||||||
|
If you start working on an issue, please assign it to yourself.
|
||||||
|
|
||||||
|
If you are adding an issue, please try to keep it focused on a single, modular bug/improvement/feature.
|
||||||
|
If two issues are related, or blocking, please link them rather than combining them.
|
||||||
|
|
||||||
|
We will try to keep these issues as up to date as possible, though
|
||||||
|
with the rapid rate of develop in this field some may get out of date.
|
||||||
|
If you notice this happening, please let us know.
|
||||||
|
|
||||||
|
### 🙋Getting Help
|
||||||
|
|
||||||
|
Our goal is to have the simplest developer setup possible. Should you experience any difficulty getting setup, please
|
||||||
|
contact a maintainer! Not only do we want to help get you unblocked, but we also want to make sure that the process is
|
||||||
|
smooth for future contributors.
|
||||||
|
|
||||||
|
In a similar vein, we do enforce certain linting, formatting, and documentation standards in the codebase.
|
||||||
|
If you are finding these difficult (or even just annoying) to work with, feel free to contact a maintainer for help -
|
||||||
|
we do not want these to get in the way of getting good code into the codebase.
|
||||||
|
|
||||||
|
## 🚀 Quick Start
|
||||||
|
|
||||||
|
> **Note:** You can run this repository locally (which is described below) or in a [development container](https://containers.dev/) (which is described in the [.devcontainer folder](https://github.com/hwchase17/Swarms/tree/master/.devcontainer)).
|
||||||
|
|
||||||
|
This project uses [Poetry](https://python-poetry.org/) as a dependency manager. Check out Poetry's [documentation on how to install it](https://python-poetry.org/docs/#installation) on your system before proceeding.
|
||||||
|
|
||||||
|
❗Note: If you use `Conda` or `Pyenv` as your environment / package manager, avoid dependency conflicts by doing the following first:
|
||||||
|
1. *Before installing Poetry*, create and activate a new Conda env (e.g. `conda create -n Swarms python=3.9`)
|
||||||
|
2. Install Poetry (see above)
|
||||||
|
3. Tell Poetry to use the virtualenv python environment (`poetry config virtualenvs.prefer-active-python true`)
|
||||||
|
4. Continue with the following steps.
|
||||||
|
|
||||||
|
To install requirements:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
poetry install -E all
|
||||||
|
```
|
||||||
|
|
||||||
|
This will install all requirements for running the package, examples, linting, formatting, tests, and coverage. Note the `-E all` flag will install all optional dependencies necessary for integration testing.
|
||||||
|
|
||||||
|
❗Note: If you're running Poetry 1.4.1 and receive a `WheelFileValidationError` for `debugpy` during installation, you can try either downgrading to Poetry 1.4.0 or disabling "modern installation" (`poetry config installer.modern-installation false`) and re-install requirements. See [this `debugpy` issue](https://github.com/microsoft/debugpy/issues/1246) for more details.
|
||||||
|
|
||||||
|
Now, you should be able to run the common tasks in the following section. To double check, run `make test`, all tests should pass. If they don't you may need to pip install additional dependencies, such as `numexpr` and `openapi_schema_pydantic`.
|
||||||
|
|
||||||
|
## ✅ Common Tasks
|
||||||
|
|
||||||
|
Type `make` for a list of common tasks.
|
||||||
|
|
||||||
|
### Code Formatting
|
||||||
|
|
||||||
|
Formatting for this project is done via a combination of [Black](https://black.readthedocs.io/en/stable/) and [isort](https://pycqa.github.io/isort/).
|
||||||
|
|
||||||
|
To run formatting for this project:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make format
|
||||||
|
```
|
||||||
|
|
||||||
|
### Linting
|
||||||
|
|
||||||
|
Linting for this project is done via a combination of [Black](https://black.readthedocs.io/en/stable/), [isort](https://pycqa.github.io/isort/), [flake8](https://flake8.pycqa.org/en/latest/), and [mypy](http://mypy-lang.org/).
|
||||||
|
|
||||||
|
To run linting for this project:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make lint
|
||||||
|
```
|
||||||
|
|
||||||
|
We recognize linting can be annoying - if you do not want to do it, please contact a project maintainer, and they can help you with it. We do not want this to be a blocker for good code getting contributed.
|
||||||
|
|
||||||
|
### Coverage
|
||||||
|
|
||||||
|
Code coverage (i.e. the amount of code that is covered by unit tests) helps identify areas of the code that are potentially more or less brittle.
|
||||||
|
|
||||||
|
To get a report of current coverage, run the following:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make coverage
|
||||||
|
```
|
||||||
|
|
||||||
|
### Working with Optional Dependencies
|
||||||
|
|
||||||
|
Swarms relies heavily on optional dependencies to keep the Swarms package lightweight.
|
||||||
|
|
||||||
|
If you're adding a new dependency to Swarms, assume that it will be an optional dependency, and
|
||||||
|
that most users won't have it installed.
|
||||||
|
|
||||||
|
Users that do not have the dependency installed should be able to **import** your code without
|
||||||
|
any side effects (no warnings, no errors, no exceptions).
|
||||||
|
|
||||||
|
To introduce the dependency to the pyproject.toml file correctly, please do the following:
|
||||||
|
|
||||||
|
1. Add the dependency to the main group as an optional dependency
|
||||||
|
```bash
|
||||||
|
poetry add --optional [package_name]
|
||||||
|
```
|
||||||
|
2. Open pyproject.toml and add the dependency to the `extended_testing` extra
|
||||||
|
3. Relock the poetry file to update the extra.
|
||||||
|
```bash
|
||||||
|
poetry lock --no-update
|
||||||
|
```
|
||||||
|
4. Add a unit test that the very least attempts to import the new code. Ideally the unit
|
||||||
|
test makes use of lightweight fixtures to test the logic of the code.
|
||||||
|
5. Please use the `@pytest.mark.requires(package_name)` decorator for any tests that require the dependency.
|
||||||
|
|
||||||
|
### Testing
|
||||||
|
|
||||||
|
See section about optional dependencies.
|
||||||
|
|
||||||
|
#### Unit Tests
|
||||||
|
|
||||||
|
Unit tests cover modular logic that does not require calls to outside APIs.
|
||||||
|
|
||||||
|
To run unit tests:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make test
|
||||||
|
```
|
||||||
|
|
||||||
|
To run unit tests in Docker:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make docker_tests
|
||||||
|
```
|
||||||
|
|
||||||
|
If you add new logic, please add a unit test.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#### Integration Tests
|
||||||
|
|
||||||
|
Integration tests cover logic that requires making calls to outside APIs (often integration with other services).
|
||||||
|
|
||||||
|
**warning** Almost no tests should be integration tests.
|
||||||
|
|
||||||
|
Tests that require making network connections make it difficult for other
|
||||||
|
developers to test the code.
|
||||||
|
|
||||||
|
Instead favor relying on `responses` library and/or mock.patch to mock
|
||||||
|
requests using small fixtures.
|
||||||
|
|
||||||
|
To run integration tests:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make integration_tests
|
||||||
|
```
|
||||||
|
|
||||||
|
If you add support for a new external API, please add a new integration test.
|
||||||
|
|
||||||
|
### Adding a Jupyter Notebook
|
||||||
|
|
||||||
|
If you are adding a Jupyter notebook example, you'll want to install the optional `dev` dependencies.
|
||||||
|
|
||||||
|
To install dev dependencies:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
poetry install --with dev
|
||||||
|
```
|
||||||
|
|
||||||
|
Launch a notebook:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
poetry run jupyter notebook
|
||||||
|
```
|
||||||
|
|
||||||
|
When you run `poetry install`, the `Swarms` package is installed as editable in the virtualenv, so your new logic can be imported into the notebook.
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
|
||||||
|
### Contribute Documentation
|
||||||
|
|
||||||
|
Docs are largely autogenerated by [sphinx](https://www.sphinx-doc.org/en/master/) from the code.
|
||||||
|
|
||||||
|
For that reason, we ask that you add good documentation to all classes and methods.
|
||||||
|
|
||||||
|
Similar to linting, we recognize documentation can be annoying. If you do not want to do it, please contact a project maintainer, and they can help you with it. We do not want this to be a blocker for good code getting contributed.
|
||||||
|
|
||||||
|
### Build Documentation Locally
|
||||||
|
|
||||||
|
Before building the documentation, it is always a good idea to clean the build directory:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make docs_clean
|
||||||
|
```
|
||||||
|
|
||||||
|
Next, you can run the linkchecker to make sure all links are valid:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make docs_linkcheck
|
||||||
|
```
|
||||||
|
|
||||||
|
Finally, you can build the documentation as outlined below:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make docs_build
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🏭 Release Process
|
||||||
|
|
||||||
|
As of now, Swarms has an ad hoc release process: releases are cut with high frequency by
|
||||||
|
a developer and published to [PyPI](https://pypi.org/project/Swarms/).
|
||||||
|
|
||||||
|
Swarms follows the [semver](https://semver.org/) versioning standard. However, as pre-1.0 software,
|
||||||
|
even patch releases may contain [non-backwards-compatible changes](https://semver.org/#spec-item-4).
|
||||||
|
|
||||||
|
### 🌟 Recognition
|
||||||
|
|
||||||
|
If your contribution has made its way into a release, we will want to give you credit on Twitter (only if you want though)!
|
||||||
|
If you have a Twitter account you would like us to mention, please let us know in the PR or in another manner.
|
@ -0,0 +1,75 @@
|
|||||||
|
# Swarms Documentation
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
The Swarm module includes the implementation of two classes, `WorkerNode` and `BossNode`, which respectively represent a worker agent and a boss agent. A worker agent is responsible for completing given tasks, while a boss agent is responsible for creating and managing tasks for the worker agent(s).
|
||||||
|
|
||||||
|
## Key Classes
|
||||||
|
|
||||||
|
### WorkerNode
|
||||||
|
```python
|
||||||
|
class WorkerNode:
|
||||||
|
```
|
||||||
|
|
||||||
|
The WorkerNode class represents an autonomous worker agent that can perform a range of tasks.
|
||||||
|
|
||||||
|
__Methods__:
|
||||||
|
|
||||||
|
- `create_agent(ai_name: str, ai_role: str, human_in_the_loop: bool, search_kwargs: dict) -> None`:
|
||||||
|
|
||||||
|
This method creates a new autonomous agent that can complete tasks. The agent utilizes several tools such as search engines, a file writer/reader, and a multi-modal visual tool.
|
||||||
|
The agent's configuration is customizable through the method parameters.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Example usage
|
||||||
|
worker_node = WorkerNode(llm, tools, vectorstore)
|
||||||
|
worker_node.create_agent('test_agent', 'test_role', False, {})
|
||||||
|
```
|
||||||
|
|
||||||
|
- `run_agent(prompt: str) -> None`:
|
||||||
|
|
||||||
|
This method runs the agent on a given task, defined by the `prompt` parameter.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Example usage
|
||||||
|
worker_node = WorkerNode(llm, tools, vectorstore)
|
||||||
|
worker_node.create_agent('test_agent', 'test_role', False, {})
|
||||||
|
worker_node.run_agent('Calculate the square root of 144.')
|
||||||
|
```
|
||||||
|
|
||||||
|
### BossNode
|
||||||
|
```python
|
||||||
|
class BossNode:
|
||||||
|
```
|
||||||
|
|
||||||
|
The BossNode class represents a manager agent that can create tasks and control the execution of these tasks.
|
||||||
|
|
||||||
|
__Methods__:
|
||||||
|
|
||||||
|
- `create_task(objective: str) -> dict`:
|
||||||
|
|
||||||
|
This method creates a new task based on the provided `objective`. The created task is a dictionary with the objective as its value.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Example usage
|
||||||
|
boss_node = BossNode(llm, vectorstore, task_execution_chain, False, 3)
|
||||||
|
task = boss_node.create_task('Find the square root of 144.')
|
||||||
|
```
|
||||||
|
|
||||||
|
- `execute_task(task: dict) -> None`:
|
||||||
|
|
||||||
|
This method triggers the execution of a given task.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Example usage
|
||||||
|
boss_node = BossNode(llm, vectorstore, task_execution_chain, False, 3)
|
||||||
|
task = boss_node.create_task('Find the square root of 144.')
|
||||||
|
boss_node.execute_task(task)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Note
|
||||||
|
|
||||||
|
Before creating the WorkerNode and BossNode, make sure to initialize the lower level model (llm), tools, and vectorstore which are used as parameters in the constructors of the two classes.
|
||||||
|
|
||||||
|
In addition, the WorkerNode class uses the MultiModalVisualAgentTool which is a custom tool that enables the worker agent to run multi-modal visual tasks. Ensure that this tool is correctly initialized before running the WorkerNode.
|
||||||
|
|
||||||
|
This documentation provides an overview of the main functionalities of the Swarm module. For additional details and advanced functionalities, please review the source code and the accompanying comments.
|
@ -0,0 +1,214 @@
|
|||||||
|
## Swarming Architectures
|
||||||
|
|
||||||
|
Here are three examples of swarming architectures that could be applied in this context.
|
||||||
|
|
||||||
|
1. **Hierarchical Swarms**: In this architecture, a 'lead' agent coordinates the efforts of other agents, distributing tasks based on each agent's unique strengths. The lead agent might be equipped with additional functionality or decision-making capabilities to effectively manage the swarm.
|
||||||
|
|
||||||
|
2. **Collaborative Swarms**: Here, each agent in the swarm works in parallel, potentially on different aspects of a task. They then collectively determine the best output, often through a voting or consensus mechanism.
|
||||||
|
|
||||||
|
3. **Competitive Swarms**: In this setup, multiple agents work on the same task independently. The output from the agent which produces the highest confidence or quality result is then selected. This can often lead to more robust outputs, as the competition drives each agent to perform at its best.
|
||||||
|
|
||||||
|
4. **Multi-Agent Debate**: Here, multiple agents debate a topic. The output from the agent which produces the highest confidence or quality result is then selected. This can lead to more robust outputs, as the competition drives each agent to perform it's best.
|
||||||
|
|
||||||
|
|
||||||
|
# Ideas
|
||||||
|
|
||||||
|
A swarm, particularly in the context of distributed computing, refers to a large number of coordinated agents or nodes that work together to solve a problem. The specific requirements of a swarm might vary depending on the task at hand, but some of the general requirements include:
|
||||||
|
|
||||||
|
1. **Distributed Nature**: The swarm should consist of multiple individual units or nodes, each capable of functioning independently.
|
||||||
|
|
||||||
|
2. **Coordination**: The nodes in the swarm need to coordinate with each other to ensure they're working together effectively. This might involve communication between nodes, or it could be achieved through a central orchestrator.
|
||||||
|
|
||||||
|
3. **Scalability**: A well-designed swarm system should be able to scale up or down as needed, adding or removing nodes based on the task load.
|
||||||
|
|
||||||
|
4. **Resilience**: If a node in the swarm fails, it shouldn't bring down the entire system. Instead, other nodes should be able to pick up the slack.
|
||||||
|
|
||||||
|
5. **Load Balancing**: Tasks should be distributed evenly across the nodes in the swarm to avoid overloading any single node.
|
||||||
|
|
||||||
|
6. **Interoperability**: Each node should be able to interact with others, regardless of differences in underlying hardware or software.
|
||||||
|
|
||||||
|
Integrating these requirements with Large Language Models (LLMs) can be done as follows:
|
||||||
|
|
||||||
|
1. **Distributed Nature**: Each LLM agent can be considered as a node in the swarm. These agents can be distributed across multiple servers or even geographically dispersed data centers.
|
||||||
|
|
||||||
|
2. **Coordination**: An orchestrator can manage the LLM agents, assigning tasks, coordinating responses, and ensuring effective collaboration between agents.
|
||||||
|
|
||||||
|
3. **Scalability**: As the demand for processing power increases or decreases, the number of LLM agents can be adjusted accordingly.
|
||||||
|
|
||||||
|
4. **Resilience**: If an LLM agent goes offline or fails, the orchestrator can assign its tasks to other agents, ensuring the swarm continues functioning smoothly.
|
||||||
|
|
||||||
|
5. **Load Balancing**: The orchestrator can also handle load balancing, ensuring tasks are evenly distributed amongst the LLM agents.
|
||||||
|
|
||||||
|
6. **Interoperability**: By standardizing the input and output formats of the LLM agents, they can effectively communicate and collaborate, regardless of the specific model or configuration of each agent.
|
||||||
|
|
||||||
|
In terms of architecture, the swarm might look something like this:
|
||||||
|
|
||||||
|
```
|
||||||
|
(Orchestrator)
|
||||||
|
/ \
|
||||||
|
Tools + Vector DB -- (LLM Agent)---(Communication Layer) (Communication Layer)---(LLM Agent)-- Tools + Vector DB
|
||||||
|
/ | | \
|
||||||
|
(Task Assignment) (Task Completion) (Task Assignment) (Task Completion)
|
||||||
|
```
|
||||||
|
|
||||||
|
Each LLM agent communicates with the orchestrator through a dedicated communication layer. The orchestrator assigns tasks to each LLM agent, which the agents then complete and return. This setup allows for a high degree of flexibility, scalability, and robustness.
|
||||||
|
|
||||||
|
|
||||||
|
## Communication Layer
|
||||||
|
|
||||||
|
Communication layers play a critical role in distributed systems, enabling interaction between different nodes (agents) and the orchestrator. Here are three potential communication layers for a distributed system, including their strengths and weaknesses:
|
||||||
|
|
||||||
|
1. **Message Queuing Systems (like RabbitMQ, Kafka)**:
|
||||||
|
|
||||||
|
- Strengths: They are highly scalable, reliable, and designed for high-throughput systems. They also ensure delivery of messages and can persist them if necessary. Furthermore, they support various messaging patterns like publish/subscribe, which can be highly beneficial in a distributed system. They also have robust community support.
|
||||||
|
|
||||||
|
- Weaknesses: They can add complexity to the system, including maintenance of the message broker. Moreover, they require careful configuration to perform optimally, and handling failures can sometimes be challenging.
|
||||||
|
|
||||||
|
2. **RESTful APIs**:
|
||||||
|
|
||||||
|
- Strengths: REST is widely adopted, and most programming languages have libraries to easily create RESTful APIs. They leverage standard HTTP(S) protocols and methods and are straightforward to use. Also, they can be stateless, meaning each request contains all the necessary information, enabling scalability.
|
||||||
|
|
||||||
|
- Weaknesses: For real-time applications, REST may not be the best fit due to its synchronous nature. Additionally, handling a large number of API requests can put a strain on the system, causing slowdowns or timeouts.
|
||||||
|
|
||||||
|
3. **gRPC (Google Remote Procedure Call)**:
|
||||||
|
|
||||||
|
- Strengths: gRPC uses Protocol Buffers as its interface definition language, leading to smaller payloads and faster serialization/deserialization compared to JSON (commonly used in RESTful APIs). It supports bidirectional streaming and can use HTTP/2 features, making it excellent for real-time applications.
|
||||||
|
|
||||||
|
- Weaknesses: gRPC is more complex to set up compared to REST. Protocol Buffers' binary format can be more challenging to debug than JSON. It's also not as widely adopted as REST, so tooling and support might be limited in some environments.
|
||||||
|
|
||||||
|
In the context of swarm LLMs, one could consider an **Omni-Vector Embedding Database** for communication. This database could store and manage the high-dimensional vectors produced by each LLM agent.
|
||||||
|
|
||||||
|
- Strengths: This approach would allow for similarity-based lookup and matching of LLM-generated vectors, which can be particularly useful for tasks that involve finding similar outputs or recognizing patterns.
|
||||||
|
|
||||||
|
- Weaknesses: An Omni-Vector Embedding Database might add complexity to the system in terms of setup and maintenance. It might also require significant computational resources, depending on the volume of data being handled and the complexity of the vectors. The handling and transmission of high-dimensional vectors could also pose challenges in terms of network load.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Technical Analysis Document: Particle Swarm of AI Agents using Ocean Database
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The goal is to create a particle swarm of AI agents using the OpenAI API for the agents and the Ocean database as the communication space, where the embeddings act as particles. The swarm will work collectively to perform tasks and optimize their behavior based on the interaction with the Ocean database.
|
||||||
|
|
||||||
|
## Algorithmic Overview
|
||||||
|
|
||||||
|
1. Initialize the AI agents and the Ocean database.
|
||||||
|
2. Assign tasks to the AI agents.
|
||||||
|
3. AI agents use the OpenAI API to perform tasks and generate embeddings.
|
||||||
|
4. AI agents store their embeddings in the Ocean database.
|
||||||
|
5. AI agents query the Ocean database for relevant embeddings.
|
||||||
|
6. AI agents update their positions based on the retrieved embeddings.
|
||||||
|
7. Evaluate the performance of the swarm and update the agents' behavior accordingly.
|
||||||
|
8. Repeat steps 3-7 until a stopping criterion is met.
|
||||||
|
|
||||||
|
## Python Implementation Logic
|
||||||
|
|
||||||
|
1. **Initialize the AI agents and the Ocean database.**
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
import oceandb
|
||||||
|
from oceandb.utils.embedding_functions import ImageBindEmbeddingFunction
|
||||||
|
|
||||||
|
# Initialize Ocean database
|
||||||
|
client = oceandb.Client()
|
||||||
|
text_embedding_function = ImageBindEmbeddingFunction(modality="text")
|
||||||
|
collection = client.create_collection("all-my-documents", embedding_function=text_embedding_function)
|
||||||
|
|
||||||
|
# Initialize AI agents
|
||||||
|
agents = initialize_agents(...)
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Assign tasks to the AI agents.**
|
||||||
|
|
||||||
|
```python
|
||||||
|
tasks = assign_tasks_to_agents(agents, ...)
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **AI agents use the OpenAI API to perform tasks and generate embeddings.**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def agent_perform_task(agent, task):
|
||||||
|
# Perform the task using the OpenAI API
|
||||||
|
result = perform_task_with_openai_api(agent, task)
|
||||||
|
# Generate the embedding
|
||||||
|
embedding = generate_embedding(result)
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
embeddings = [agent_perform_task(agent, task) for agent, task in zip(agents, tasks)]
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **AI agents store their embeddings in the Ocean database.**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def store_embeddings_in_database(embeddings, collection):
|
||||||
|
for i, embedding in enumerate(embeddings):
|
||||||
|
document_id = f"agent_{i}"
|
||||||
|
collection.add(documents=[embedding], ids=[document_id])
|
||||||
|
|
||||||
|
store_embeddings_in_database(embeddings, collection)
|
||||||
|
```
|
||||||
|
|
||||||
|
5. **AI agents query the Ocean database for relevant embeddings.**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def query_database_for_embeddings(agent, collection, n_results=1):
|
||||||
|
query_result = collection.query(query_texts=[agent], n_results=n_results)
|
||||||
|
return query_result
|
||||||
|
|
||||||
|
queried_embeddings = [query_database_for_embeddings(agent, collection) for agent in agents]
|
||||||
|
```
|
||||||
|
|
||||||
|
6. **AI agents update their positions based on the retrieved embeddings.**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def update_agent_positions(agents, queried_embeddings):
|
||||||
|
for agent, embedding in zip(agents, queried_embeddings):
|
||||||
|
agent.update_position(embedding)
|
||||||
|
|
||||||
|
update_agent_positions(agents, queried_embeddings)
|
||||||
|
```
|
||||||
|
|
||||||
|
7. **Evaluate the performance of the swarm and update the agents' behavior accordingly.**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def evaluate_swarm_performance(agents, ...):
|
||||||
|
# Evaluate the performance of the swarm
|
||||||
|
performance = compute_performance_metric(agents, ...)
|
||||||
|
return performance
|
||||||
|
|
||||||
|
def update_agent_behavior(agents, performance):
|
||||||
|
# Update agents' behavior based on swarm performance
|
||||||
|
for agent in agents:
|
||||||
|
agent.adjust_behavior(performance)
|
||||||
|
|
||||||
|
performance = evaluate_swarm_performance(agents, ...)
|
||||||
|
update_agent_behavior(agents, performance)
|
||||||
|
```
|
||||||
|
|
||||||
|
8. **Repeat steps 3-7 until a stopping criterion is met.**
|
||||||
|
|
||||||
|
```python
|
||||||
|
while not stopping_criterion_met():
|
||||||
|
# Perform tasks and generate embeddings
|
||||||
|
embeddings = [agent_perform_task(agent, task) for agent, task in zip(agents, tasks)]
|
||||||
|
|
||||||
|
# Store embeddings in the Ocean database
|
||||||
|
store_embeddings_in_database(embeddings, collection)
|
||||||
|
|
||||||
|
# Query the Ocean database for relevant embeddings
|
||||||
|
queried_embeddings = [query_database_for_embeddings(agent, collection) for agent in agents]
|
||||||
|
|
||||||
|
# Update AI agent positions based on the retrieved embeddings
|
||||||
|
update_agent_positions(agents, queried_embeddings)
|
||||||
|
|
||||||
|
# Evaluate the performance of the swarm and update the agents' behavior accordingly
|
||||||
|
performance = evaluate_swarm_performance(agents, ...)
|
||||||
|
update_agent_behavior(agents, performance)
|
||||||
|
```
|
||||||
|
|
||||||
|
This code demonstrates the complete loop to repeat steps 3-7 until a stopping criterion is met. You will need to define the `stopping_criterion_met()` function, which could be based on a predefined number of iterations, a target performance level, or any other condition that indicates that the swarm has reached a desired state.
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,13 @@
|
|||||||
|
Today, we stand at the verge of a revolution in artificial intelligence and machine learning. Individual models have accomplished incredible feats, achieving unprecedented levels of understanding and generating incredibly human-like text. But this is just the beginning.
|
||||||
|
|
||||||
|
In the future, we should expect more. These models, which we've seen perform so admirably in isolation, should be able to work together, as a team, a swarm. However, this kind of collaborative intelligence doesn't exist today. That's because the technology to seamlessly integrate these models and foster true inter-model collaboration has been missing, until now.
|
||||||
|
|
||||||
|
In attempting to create this swarm, we face numerous challenges, such as developing the necessary infrastructure, ensuring seamless integration between the agents, and overcoming the practical limitations of our current computing capabilities. These are daunting tasks, and many have shied away from them because of the sheer complexity of the problem. But, if we can overcome these challenges, the rewards will be unimaginable, all digital activities will be automated.
|
||||||
|
|
||||||
|
We envision a future where swarms of Language Learning Model (LLM) agents revolutionize fields like customer support, content creation, and research. Imagine an AI system that could work cohesively, understand complex problems, and deliver multi-faceted solutions. We estimate this could lead to a 100-fold improvement in AI effectiveness, and up to a trillion-dollar impact on the global economy.
|
||||||
|
|
||||||
|
The secret to achieving this lies in our open-source approach and the power of the collective. By embracing open-source, we are enabling hundreds of thousands of minds worldwide to contribute to this vision, each bringing unique insights and solutions. Our bug bounty program and automated testing environments will act as catalysts, motivating and rewarding contributors while ensuring the robustness and reliability of our technology.
|
||||||
|
|
||||||
|
At Agora, we believe in the transformative potential of this technology, and we are committed to making it a reality. Our world-class team of researchers, engineers, and AI enthusiasts are singularly focused on this mission. With a proven track record of success, and the tenacity to tackle the most complex problems, we are best positioned to lead this charge.
|
||||||
|
|
||||||
|
We invite you to join us on this exciting journey. Let's come together to create swarms, advance humanity, and redefine what is possible with artificial intelligence. Our future is in our hands. Let's shape it together.
|
@ -0,0 +1,149 @@
|
|||||||
|
# Bounty Program
|
||||||
|
|
||||||
|
Our bounty program is an exciting opportunity for contributors to help us build the future of Swarms. By participating, you can earn rewards while contributing to a project that aims to revolutionize digital activity.
|
||||||
|
|
||||||
|
Here's how it works:
|
||||||
|
|
||||||
|
1. **Check out our Roadmap**: We've shared our roadmap detailing our short and long-term goals. These are the areas where we're seeking contributions.
|
||||||
|
|
||||||
|
2. **Pick a Task**: Choose a task from the roadmap that aligns with your skills and interests. If you're unsure, you can reach out to our team for guidance.
|
||||||
|
|
||||||
|
3. **Get to Work**: Once you've chosen a task, start working on it. Remember, quality is key. We're looking for contributions that truly make a difference.
|
||||||
|
|
||||||
|
4. **Submit your Contribution**: Once your work is complete, submit it for review. We'll evaluate your contribution based on its quality, relevance, and the value it brings to Swarms.
|
||||||
|
|
||||||
|
5. **Earn Rewards**: If your contribution is approved, you'll earn a bounty. The amount of the bounty depends on the complexity of the task, the quality of your work, and the value it brings to Swarms.
|
||||||
|
|
||||||
|
## The Three Phases of Our Bounty Program
|
||||||
|
|
||||||
|
### Phase 1: Building the Foundation
|
||||||
|
In the first phase, our focus is on building the basic infrastructure of Swarms. This includes developing key components like the Swarms class, integrating essential tools, and establishing task completion and evaluation logic. We'll also start developing our testing and evaluation framework during this phase. If you're interested in foundational work and have a knack for building robust, scalable systems, this phase is for you.
|
||||||
|
|
||||||
|
### Phase 2: Enhancing the System
|
||||||
|
In the second phase, we'll focus on enhancing Swarms by integrating more advanced features, improving the system's efficiency, and refining our testing and evaluation framework. This phase involves more complex tasks, so if you enjoy tackling challenging problems and contributing to the development of innovative features, this is the phase for you.
|
||||||
|
|
||||||
|
### Phase 3: Towards Super-Intelligence
|
||||||
|
The third phase of our bounty program is the most exciting - this is where we aim to achieve super-intelligence. In this phase, we'll be working on improving the swarm's capabilities, expanding its skills, and fine-tuning the system based on real-world testing and feedback. If you're excited about the future of AI and want to contribute to a project that could potentially transform the digital world, this is the phase for you.
|
||||||
|
|
||||||
|
Remember, our roadmap is a guide, and we encourage you to bring your own ideas and creativity to the table. We believe that every contribution, no matter how small, can make a difference. So join us on this exciting journey and help us create the future of Swarms.
|
||||||
|
|
||||||
|
**To participate in our bounty program, visit the [Swarms Bounty Program Page](https://swarms.ai/bounty).** Let's build the future together!
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Bounties for Roadmap Items
|
||||||
|
|
||||||
|
To accelerate the development of Swarms and to encourage more contributors to join our journey towards automating every digital activity in existence, we are announcing a Bounty Program for specific roadmap items. Each bounty will be rewarded based on the complexity and importance of the task. Below are the items available for bounty:
|
||||||
|
|
||||||
|
1. **Multi-Agent Debate Integration**: $2000
|
||||||
|
2. **Meta Prompting Integration**: $1500
|
||||||
|
3. **Swarms Class**: $1500
|
||||||
|
4. **Integration of Additional Tools**: $1000
|
||||||
|
5. **Task Completion and Evaluation Logic**: $2000
|
||||||
|
6. **Ocean Integration**: $2500
|
||||||
|
7. **Improved Communication**: $2000
|
||||||
|
8. **Testing and Evaluation**: $1500
|
||||||
|
9. **Worker Swarm Class**: $2000
|
||||||
|
10. **Documentation**: $500
|
||||||
|
|
||||||
|
For each bounty task, there will be a strict evaluation process to ensure the quality of the contribution. This process includes a thorough review of the code and extensive testing to ensure it meets our standards.
|
||||||
|
|
||||||
|
# 3-Phase Testing Framework
|
||||||
|
|
||||||
|
To ensure the quality and efficiency of the Swarm, we will introduce a 3-phase testing framework which will also serve as our evaluation criteria for each of the bounty tasks.
|
||||||
|
|
||||||
|
## Phase 1: Unit Testing
|
||||||
|
In this phase, individual modules will be tested to ensure that they work correctly in isolation. Unit tests will be designed for all functions and methods, with an emphasis on edge cases.
|
||||||
|
|
||||||
|
## Phase 2: Integration Testing
|
||||||
|
After passing unit tests, we will test the integration of different modules to ensure they work correctly together. This phase will also test the interoperability of the Swarm with external systems and libraries.
|
||||||
|
|
||||||
|
## Phase 3: Benchmarking & Stress Testing
|
||||||
|
In the final phase, we will perform benchmarking and stress tests. We'll push the limits of the Swarm under extreme conditions to ensure it performs well in real-world scenarios. This phase will measure the performance, speed, and scalability of the Swarm under high load conditions.
|
||||||
|
|
||||||
|
By following this 3-phase testing framework, we aim to develop a reliable, high-performing, and scalable Swarm that can automate all digital activities.
|
||||||
|
|
||||||
|
# Reverse Engineering to Reach Phase 3
|
||||||
|
|
||||||
|
To reach the Phase 3 level, we need to reverse engineer the tasks we need to complete. Here's an example of what this might look like:
|
||||||
|
|
||||||
|
1. **Set Clear Expectations**: Define what success looks like for each task. Be clear about the outputs and outcomes we expect. This will guide our testing and development efforts.
|
||||||
|
|
||||||
|
2. **Develop Testing Scenarios**: Create a comprehensive list of testing scenarios that cover both common and edge cases. This will help us ensure that our Swarm can handle a wide range of situations.
|
||||||
|
|
||||||
|
3. **Write Test Cases**: For each scenario, write detailed test cases that outline the exact steps to be followed, the inputs to be used, and the expected outputs.
|
||||||
|
|
||||||
|
4. **Execute the Tests**: Run the test cases on our Swarm, making note of any issues or bugs that arise.
|
||||||
|
|
||||||
|
5. **Iterate and Improve**: Based on the results of our tests, iterate and improve our Swarm. This may involve fixing bugs, optimizing code, or redesigning parts of our system.
|
||||||
|
|
||||||
|
6. **Repeat**: Repeat this process until our Swarm meets our expectations and passes all test cases.
|
||||||
|
|
||||||
|
By following these steps, we will systematically build, test, and improve our Swarm until it reaches the Phase 3 level. This methodical approach will help us ensure that we create a reliable, high-performing, and scalable Swarm that can truly automate all digital activities.
|
||||||
|
|
||||||
|
Let's shape the future of digital automation together!
|
||||||
|
|
||||||
|
|
||||||
|
--------------------
|
||||||
|
# Super-Intelligence Roadmap
|
||||||
|
|
||||||
|
Creating a Super-Intelligent Swarm involves three main phases, where each phase has multiple sub-stages, each of which will require rigorous testing and evaluation to ensure progress towards super-intelligence.
|
||||||
|
|
||||||
|
## Phase 1: Narrow Intelligence
|
||||||
|
|
||||||
|
In this phase, the goal is to achieve high performance in specific tasks. These tasks will be predefined and the swarm will be trained and tested on these tasks.
|
||||||
|
|
||||||
|
1. **Single Task Mastery**: Focus on mastering one task at a time. This can range from simple tasks like image recognition to complex tasks like natural language processing.
|
||||||
|
|
||||||
|
2. **Task Switching**: Train the swarm to switch between different tasks effectively. This includes being able to stop one task and start another one without any loss in performance.
|
||||||
|
|
||||||
|
3. **Multi-tasking**: The swarm should be capable of performing multiple tasks simultaneously without any degradation in performance.
|
||||||
|
|
||||||
|
## Phase 2: General Intelligence
|
||||||
|
|
||||||
|
In this phase, the swarm will be trained to handle a variety of tasks that were not part of the original training set.
|
||||||
|
|
||||||
|
1. **Transfer Learning**: The swarm should be able to transfer knowledge learned in one context to another context. This means being able to apply knowledge learned in one task to a different but related task.
|
||||||
|
|
||||||
|
2. **Adaptive Learning**: The swarm should be capable of adapting its learning strategies based on the task at hand. This includes being able to adjust its learning rate, exploration vs exploitation balance, etc.
|
||||||
|
|
||||||
|
3. **Self-Learning**: The swarm should be able to learn new tasks on its own without any external guidance. This includes being able to understand the task requirements, find relevant information, learn the task, and evaluate its performance.
|
||||||
|
|
||||||
|
## Phase 3: Super Intelligence
|
||||||
|
|
||||||
|
In this phase, the swarm will surpass human-level performance in most economically valuable work. This involves the swarm being able to solve complex real-world problems, make accurate predictions, and generate innovative solutions.
|
||||||
|
|
||||||
|
1. **Complex Problem Solving**: The swarm should be able to solve complex real-world problems. This includes being able to understand the problem, identify relevant information, generate solutions, evaluate the solutions, and implement the best solution.
|
||||||
|
|
||||||
|
2. **Predictive Abilities**: The swarm should be able to make accurate predictions about future events based on past data. This includes being able to understand the data, identify relevant patterns, make accurate predictions, and evaluate the accuracy of its predictions.
|
||||||
|
|
||||||
|
3. **Innovation**: The swarm should be able to generate innovative solutions to problems. This includes being able to think creatively, generate novel ideas, evaluate the ideas, and implement the best idea.
|
||||||
|
|
||||||
|
4. **Self-improvement**: The swarm should be capable of improving its own capabilities. This includes being able to identify areas of weakness, find ways to improve, and implement the improvements.
|
||||||
|
|
||||||
|
5. **Understanding**: The swarm should be able to understand complex concepts, make inferences, and draw conclusions. This includes being able to understand natural language, reason logically, and make sound judgments.
|
||||||
|
|
||||||
|
Each of these stages will require extensive testing and evaluation to ensure progress towards super-intelligence.
|
||||||
|
|
||||||
|
# Reverse-Engineering Super-Intelligence
|
||||||
|
|
||||||
|
To reach the Phase 3 level of super-intelligence, we need to reverse engineer the tasks that need to be completed. Here's an outline of what this might look like:
|
||||||
|
|
||||||
|
1. **Setting Success Metrics**: For each stage, define clear success metrics. These metrics should be quantitative and measurable, and they should align with the objectives of the stage.
|
||||||
|
|
||||||
|
2. **Identifying Prerequisites**: Determine what needs to be in place before each stage can begin. This could include certain capabilities, resources, or technologies.
|
||||||
|
|
||||||
|
3. **Developing Training Programs**: For each stage, develop a comprehensive training program. This should include a variety of tasks that will challenge the swarm and push it to
|
||||||
|
|
||||||
|
develop the necessary capabilities.
|
||||||
|
|
||||||
|
4. **Creating Testing Protocols**: Develop rigorous testing protocols for each stage. These protocols should test all aspects of the swarm's performance and they should be designed to push the swarm to its limits.
|
||||||
|
|
||||||
|
5. **Iterating and Improving**: Based on the results of the tests, iterate and improve the swarm. This could involve adjusting the training program, modifying the swarm's architecture, or tweaking its learning algorithms.
|
||||||
|
|
||||||
|
6. **Moving to the Next Stage**: Once the swarm has met the success metrics for a stage, it can move on to the next stage. This process continues until the swarm has reached the level of super-intelligence.
|
||||||
|
|
||||||
|
This process will require a significant amount of time, resources, and effort. However, by following this structured approach, we can systematically guide the swarm towards super-intelligence.
|
||||||
|
|
@ -0,0 +1,91 @@
|
|||||||
|
Jeff Bezos, the founder of Amazon.com, is known for his customer-centric approach and long-term strategic thinking. Leveraging his methodology, here are five ways you could monetize the Swarms framework:
|
||||||
|
|
||||||
|
1. **Platform as a Service (PaaS):** Create a cloud-based platform that allows users to build, run, and manage applications without the complexity of maintaining the infrastructure. You could charge users a subscription fee for access to the platform and provide different pricing tiers based on usage levels. This could be an attractive solution for businesses that do not have the capacity to build or maintain their own swarm intelligence solutions.
|
||||||
|
|
||||||
|
2. **Professional Services:** Offer consultancy and implementation services to businesses looking to utilize the Swarm technology. This could include assisting with integration into existing systems, offering custom development services, or helping customers to build specific solutions using the framework.
|
||||||
|
|
||||||
|
3. **Education and Training:** Create a certification program for developers or companies looking to become proficient with the Swarms framework. This could be sold as standalone courses, or bundled with other services.
|
||||||
|
|
||||||
|
4. **Managed Services:** Some companies may prefer to outsource the management of their Swarm-based systems. A managed services solution could take care of all the technical aspects, from hosting the solution to ensuring it runs smoothly, allowing the customer to focus on their core business.
|
||||||
|
|
||||||
|
5. **Data Analysis and Insights:** Swarm intelligence can generate valuable data and insights. By anonymizing and aggregating this data, you could provide industry reports, trend analysis, and other valuable insights to businesses.
|
||||||
|
|
||||||
|
As for the type of platform, Swarms can be offered as a cloud-based solution given its scalability and flexibility. This would also allow you to apply a SaaS/PaaS type monetization model, which provides recurring revenue.
|
||||||
|
|
||||||
|
Potential customers could range from small to large enterprises in various sectors such as logistics, eCommerce, finance, and technology, who are interested in leveraging artificial intelligence and machine learning for complex problem solving, optimization, and decision-making.
|
||||||
|
|
||||||
|
**Product Brief Monetization Strategy:**
|
||||||
|
|
||||||
|
Product Name: Swarms.AI Platform
|
||||||
|
|
||||||
|
Product Description: A cloud-based AI and ML platform harnessing the power of swarm intelligence.
|
||||||
|
|
||||||
|
1. **Platform as a Service (PaaS):** Offer tiered subscription plans (Basic, Premium, Enterprise) to accommodate different usage levels and business sizes.
|
||||||
|
|
||||||
|
2. **Professional Services:** Offer consultancy and custom development services to tailor the Swarms solution to the specific needs of the business.
|
||||||
|
|
||||||
|
3. **Education and Training:** Launch an online Swarms.AI Academy with courses and certifications for developers and businesses.
|
||||||
|
|
||||||
|
4. **Managed Services:** Provide a premium, fully-managed service offering that includes hosting, maintenance, and 24/7 support.
|
||||||
|
|
||||||
|
5. **Data Analysis and Insights:** Offer industry reports and customized insights generated from aggregated and anonymized Swarm data.
|
||||||
|
|
||||||
|
Potential Customers: Enterprises in sectors such as logistics, eCommerce, finance, and technology. This can be sold globally, provided there's an internet connection.
|
||||||
|
|
||||||
|
Marketing Channels: Online marketing (SEO, Content Marketing, Social Media), Partnerships with tech companies, Direct Sales to Enterprises.
|
||||||
|
|
||||||
|
This strategy is designed to provide multiple revenue streams, while ensuring the Swarms.AI platform is accessible and useful to a range of potential customers.
|
||||||
|
|
||||||
|
1. **AI Solution as a Service:** By offering the Swarms framework as a service, businesses can access and utilize the power of multiple LLM agents without the need to maintain the infrastructure themselves. Subscription can be tiered based on usage and additional features.
|
||||||
|
|
||||||
|
2. **Integration and Custom Development:** Offer integration services to businesses wanting to incorporate the Swarms framework into their existing systems. Also, you could provide custom development for businesses with specific needs not met by the standard framework.
|
||||||
|
|
||||||
|
3. **Training and Certification:** Develop an educational platform offering courses, webinars, and certifications on using the Swarms framework. This can serve both developers seeking to broaden their skills and businesses aiming to train their in-house teams.
|
||||||
|
|
||||||
|
4. **Managed Swarms Solutions:** For businesses that prefer to outsource their AI needs, provide a complete solution which includes the development, maintenance, and continuous improvement of swarms-based applications.
|
||||||
|
|
||||||
|
5. **Data Analytics Services:** Leveraging the aggregated insights from the AI swarms, you could offer data analytics services. Businesses can use these insights to make informed decisions and predictions.
|
||||||
|
|
||||||
|
**Type of Platform:**
|
||||||
|
|
||||||
|
Cloud-based platform or Software as a Service (SaaS) will be a suitable model. It offers accessibility, scalability, and ease of updates.
|
||||||
|
|
||||||
|
**Target Customers:**
|
||||||
|
|
||||||
|
The technology can be beneficial for businesses across sectors like eCommerce, technology, logistics, finance, healthcare, and education, among others.
|
||||||
|
|
||||||
|
**Product Brief Monetization Strategy:**
|
||||||
|
|
||||||
|
Product Name: Swarms.AI
|
||||||
|
|
||||||
|
1. **AI Solution as a Service:** Offer different tiered subscriptions (Standard, Premium, and Enterprise) each with varying levels of usage and features.
|
||||||
|
|
||||||
|
2. **Integration and Custom Development:** Offer custom development and integration services, priced based on the scope and complexity of the project.
|
||||||
|
|
||||||
|
3. **Training and Certification:** Launch the Swarms.AI Academy with courses and certifications, available for a fee.
|
||||||
|
|
||||||
|
4. **Managed Swarms Solutions:** Offer fully managed solutions tailored to business needs, priced based on scope and service level agreements.
|
||||||
|
|
||||||
|
5. **Data Analytics Services:** Provide insightful reports and data analyses, which can be purchased on a one-off basis or through a subscription.
|
||||||
|
|
||||||
|
By offering a variety of services and payment models, Swarms.AI will be able to cater to a diverse range of business needs, from small start-ups to large enterprises. Marketing channels would include digital marketing, partnerships with technology companies, presence in tech events, and direct sales to targeted industries.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Roadmap
|
||||||
|
|
||||||
|
* Create a landing page for swarms apac.ai/product/swarms
|
||||||
|
|
||||||
|
* Create Hosted Swarms API for anybody to just use without need for mega gpu infra, charge usage based pricing. Prerequisites for success => Swarms has to be extremely reliable + we need world class documentation and many daily users => how do we get many daily users? We provide a seamless and fluid experience, how do we create a seamless and fluid experience? We write good code that is modular, provides feedback to the user in times of distress, and ultimately accomplishes the user's tasks.
|
||||||
|
|
||||||
|
* Hosted consumer and enterprise subscription as a service on The Domain, where users can interact with 1000s of APIs and ingest 1000s of different data streams.
|
||||||
|
|
||||||
|
* Hosted dedicated capacity deals with mega enterprises on automating many operations with Swarms for monthly subscription 300,000+$
|
||||||
|
|
||||||
|
* Partnerships with enterprises, massive contracts with performance based fee
|
||||||
|
|
||||||
|
* Have discord bot and or slack bot with users personal data, charge subscription + browser extension
|
||||||
|
|
||||||
|
* each user gets a dedicated ocean instance of all their data so the swarm can query it as needed.
|
||||||
|
|
||||||
|
*
|
@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
@ -0,0 +1,26 @@
|
|||||||
|
<!-- Thank you for contributing to Swarms!
|
||||||
|
|
||||||
|
Replace this comment with:
|
||||||
|
- Description: a description of the change,
|
||||||
|
- Issue: the issue # it fixes (if applicable),
|
||||||
|
- Dependencies: any dependencies required for this change,
|
||||||
|
- Tag maintainer: for a quicker response, tag the relevant maintainer (see below),
|
||||||
|
- Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out!
|
||||||
|
|
||||||
|
If you're adding a new integration, please include:
|
||||||
|
1. a test for the integration, preferably unit tests that do not rely on network access,
|
||||||
|
2. an example notebook showing its use.
|
||||||
|
|
||||||
|
Maintainer responsibilities:
|
||||||
|
- General / Misc / if you don't know who to tag: kye@apac.ai
|
||||||
|
- DataLoaders / VectorStores / Retrievers: kye@apac.ai
|
||||||
|
- Models / Prompts: kye@apac.ai
|
||||||
|
- Memory: kye@apac.ai
|
||||||
|
- Agents / Tools / Toolkits: kye@apac.ai
|
||||||
|
- Tracing / Callbacks: kye@apac.ai
|
||||||
|
- Async: kye@apac.ai
|
||||||
|
|
||||||
|
If no one reviews your PR within a few days, feel free to kye@apac.ai
|
||||||
|
|
||||||
|
See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
|
||||||
|
-->
|
@ -0,0 +1,61 @@
|
|||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from fastapi.templating import Jinja2Templates
|
||||||
|
|
||||||
|
from swarms.agents.workers.agents import AgentManager
|
||||||
|
from swarms.utils.utils import BaseHandler, FileHandler, FileType, StaticUploader, CsvToDataframe
|
||||||
|
|
||||||
|
from swarms.tools.main import BaseToolSet, ExitConversation, RequestsGet, CodeEditor, Terminal
|
||||||
|
|
||||||
|
from env import settings
|
||||||
|
|
||||||
|
|
||||||
|
BASE_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
os.chdir(BASE_DIR / os.getenv["PLAYGROUND_DIR"])
|
||||||
|
|
||||||
|
|
||||||
|
toolsets: List[BaseToolSet] = [
|
||||||
|
Terminal(),
|
||||||
|
CodeEditor(),
|
||||||
|
RequestsGet(),
|
||||||
|
ExitConversation(),
|
||||||
|
]
|
||||||
|
handlers: Dict[FileType, BaseHandler] = {FileType.DATAFRAME: CsvToDataframe()}
|
||||||
|
|
||||||
|
if os.getenv["USE_GPU"]:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from swarms.tools.main import ImageCaptioning
|
||||||
|
from swarms.tools.main import (
|
||||||
|
ImageEditing,
|
||||||
|
InstructPix2Pix,
|
||||||
|
Text2Image,
|
||||||
|
VisualQuestionAnswering,
|
||||||
|
)
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
toolsets.extend(
|
||||||
|
[
|
||||||
|
Text2Image("cuda"),
|
||||||
|
ImageEditing("cuda"),
|
||||||
|
InstructPix2Pix("cuda"),
|
||||||
|
VisualQuestionAnswering("cuda"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
handlers[FileType.IMAGE] = ImageCaptioning("cuda")
|
||||||
|
|
||||||
|
agent_manager = AgentManager.create(toolsets=toolsets)
|
||||||
|
|
||||||
|
file_handler = FileHandler(handlers=handlers, path=BASE_DIR)
|
||||||
|
|
||||||
|
templates = Jinja2Templates(directory=BASE_DIR / "api" / "templates")
|
||||||
|
|
||||||
|
uploader = StaticUploader.from_settings(
|
||||||
|
settings, path=BASE_DIR / "static", endpoint="static"
|
||||||
|
)
|
||||||
|
|
||||||
|
reload_dirs = [BASE_DIR / "swarms", BASE_DIR / "api"]
|
@ -0,0 +1,130 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
from multiprocessing import Process
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
|
|
||||||
|
from typing import List, TypedDict
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI, Request, UploadFile
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
|
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from api.container import agent_manager, file_handler, reload_dirs, templates, uploader
|
||||||
|
from api.worker import get_task_result, start_worker, task_execute
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
app.mount("/static", StaticFiles(directory=uploader.path), name="static")
|
||||||
|
|
||||||
|
|
||||||
|
class ExecuteRequest(BaseModel):
|
||||||
|
session: str
|
||||||
|
prompt: str
|
||||||
|
files: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class ExecuteResponse(TypedDict):
|
||||||
|
answer: str
|
||||||
|
files: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/", response_class=HTMLResponse)
|
||||||
|
async def index(request: Request):
|
||||||
|
return templates.TemplateResponse("index.html", {"request": request})
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/dashboard", response_class=HTMLResponse)
|
||||||
|
async def dashboard(request: Request):
|
||||||
|
return templates.TemplateResponse("dashboard.html", {"request": request})
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/upload")
|
||||||
|
async def create_upload_file(files: List[UploadFile]):
|
||||||
|
urls = []
|
||||||
|
for file in files:
|
||||||
|
extension = "." + file.filename.split(".")[-1]
|
||||||
|
with NamedTemporaryFile(suffix=extension) as tmp_file:
|
||||||
|
tmp_file.write(file.file.read())
|
||||||
|
tmp_file.flush()
|
||||||
|
urls.append(uploader.upload(tmp_file.name))
|
||||||
|
return {"urls": urls}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/execute")
|
||||||
|
async def execute(request: ExecuteRequest) -> ExecuteResponse:
|
||||||
|
query = request.prompt
|
||||||
|
files = request.files
|
||||||
|
session = request.session
|
||||||
|
|
||||||
|
executor = agent_manager.create_executor(session)
|
||||||
|
|
||||||
|
promptedQuery = "\n".join([file_handler.handle(file) for file in files])
|
||||||
|
promptedQuery += query
|
||||||
|
|
||||||
|
try:
|
||||||
|
res = executor({"input": promptedQuery})
|
||||||
|
except Exception as e:
|
||||||
|
return {"answer": str(e), "files": []}
|
||||||
|
|
||||||
|
files = re.findall(r"\[file://\S*\]", res["output"])
|
||||||
|
files = [file[1:-1].split("file://")[1] for file in files]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"answer": res["output"],
|
||||||
|
"files": [uploader.upload(file) for file in files],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/execute/async")
|
||||||
|
async def execute_async(request: ExecuteRequest):
|
||||||
|
query = request.prompt
|
||||||
|
files = request.files
|
||||||
|
session = request.session
|
||||||
|
|
||||||
|
promptedQuery = "\n".join([file_handler.handle(file) for file in files])
|
||||||
|
promptedQuery += query
|
||||||
|
|
||||||
|
execution = task_execute.delay(session, promptedQuery)
|
||||||
|
return {"id": execution.id}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/execute/async/{execution_id}")
|
||||||
|
async def execute_async(execution_id: str):
|
||||||
|
execution = get_task_result(execution_id)
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
if execution.status == "SUCCESS" and execution.result:
|
||||||
|
output = execution.result.get("output", "")
|
||||||
|
files = re.findall(r"\[file://\S*\]", output)
|
||||||
|
files = [file[1:-1].split("file://")[1] for file in files]
|
||||||
|
result = {
|
||||||
|
"answer": output,
|
||||||
|
"files": [uploader.upload(file) for file in files],
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": execution.status,
|
||||||
|
"info": execution.info,
|
||||||
|
"result": result,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def serve():
|
||||||
|
p = Process(target=start_worker, args=[])
|
||||||
|
p.start()
|
||||||
|
uvicorn.run("api.main:app", host="0.0.0.0", port=os.getenv["EVAL_PORT"])
|
||||||
|
|
||||||
|
|
||||||
|
def dev():
|
||||||
|
p = Process(target=start_worker, args=[])
|
||||||
|
p.start()
|
||||||
|
uvicorn.run(
|
||||||
|
"api.main:app",
|
||||||
|
host="0.0.0.0",
|
||||||
|
port=os.getenv["EVAL_PORT"],
|
||||||
|
reload=True,
|
||||||
|
reload_dirs=reload_dirs,
|
||||||
|
)
|
@ -0,0 +1,46 @@
|
|||||||
|
import os
|
||||||
|
from celery import Celery
|
||||||
|
from celery.result import AsyncResult
|
||||||
|
|
||||||
|
from api.container import agent_manager
|
||||||
|
# from env import settings
|
||||||
|
|
||||||
|
celery_broker = os.environ["CELERY_BROKER_URL"]
|
||||||
|
|
||||||
|
|
||||||
|
celery_app = Celery(__name__)
|
||||||
|
celery_app.conf.broker_url = celery_broker
|
||||||
|
celery_app.conf.result_backend = celery_broker
|
||||||
|
celery_app.conf.update(
|
||||||
|
task_track_started=True,
|
||||||
|
task_serializer="json",
|
||||||
|
accept_content=["json"], # Ignore other content
|
||||||
|
result_serializer="json",
|
||||||
|
enable_utc=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(name="task_execute", bind=True)
|
||||||
|
def task_execute(self, session: str, prompt: str):
|
||||||
|
executor = agent_manager.create_executor(session, self)
|
||||||
|
response = executor({"input": prompt})
|
||||||
|
result = {"output": response["output"]}
|
||||||
|
|
||||||
|
previous = AsyncResult(self.request.id)
|
||||||
|
if previous and previous.info:
|
||||||
|
result.update(previous.info)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_task_result(task_id):
|
||||||
|
return AsyncResult(task_id)
|
||||||
|
|
||||||
|
|
||||||
|
def start_worker():
|
||||||
|
celery_app.worker_main(
|
||||||
|
[
|
||||||
|
"worker",
|
||||||
|
"--loglevel=INFO",
|
||||||
|
]
|
||||||
|
)
|
@ -0,0 +1,25 @@
|
|||||||
|
from swarms import Swarms
|
||||||
|
|
||||||
|
|
||||||
|
# Retrieve your API key from the environment or replace with your actual key
|
||||||
|
api_key = "sksdsds"
|
||||||
|
|
||||||
|
# Initialize Swarms with your API key
|
||||||
|
swarm = Swarms(openai_api_key=api_key)
|
||||||
|
|
||||||
|
# Define an objective
|
||||||
|
objective = """
|
||||||
|
Please make a web GUI for using HTTP API server.
|
||||||
|
The name of it is Swarms.
|
||||||
|
You can check the server code at ./main.py.
|
||||||
|
The server is served on localhost:8000.
|
||||||
|
Users should be able to write text input as 'query' and url array as 'files', and check the response.
|
||||||
|
Users input form should be delivered in JSON format.
|
||||||
|
I want it to have neumorphism-style. Serve it on port 4500.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Run Swarms
|
||||||
|
task = swarm.run_swarms(objective)
|
||||||
|
|
||||||
|
print(task)
|
@ -0,0 +1,50 @@
|
|||||||
|
# This is a basic Dockerfile and might need to be adjusted according to your specific application's needs. Please replace the placeholders for environment variables with your actual keys. Also, remember not to expose sensitive data like API keys in your Dockerfile or any version control systems.
|
||||||
|
|
||||||
|
# When building and running this Docker container, be sure to allocate enough resources (especially GPU memory) for your chosen visual foundation model if running on a machine with an NVIDIA GPU. You may need to use nvidia-docker or Docker's --gpus option when running the container. The GPU memory usage you provided would be valuable for this purpose.
|
||||||
|
|
||||||
|
# It's important to note that Docker inherently does not fully support GPUs. As a result, running GPU-accelerated code within Docker requires a special runtime like NVIDIA Docker. For more complex orchestration, a platform like Kubernetes can be more effective.
|
||||||
|
|
||||||
|
# Lastly, since your application seems to be using Redis (CELERY_BROKER_URL), you might need to set up a separate Redis service as well. This can be accomplished by creating a multi-container Docker application using Docker Compose or Kubernetes.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Use an official Python runtime as a parent image
|
||||||
|
FROM python:3.9-slim-buster
|
||||||
|
|
||||||
|
# Set environment variables
|
||||||
|
ENV EVAL_PORT=8000 \
|
||||||
|
MODEL_NAME=gpt-4 \
|
||||||
|
CELERY_BROKER_URL=redis://localhost:6379 \
|
||||||
|
SERVER=http://localhost:${EVAL_PORT} \
|
||||||
|
USE_GPU=False \
|
||||||
|
PLAYGROUND_DIR=playground \
|
||||||
|
LOG_LEVEL=INFO \
|
||||||
|
BOT_NAME=Orca \
|
||||||
|
# You will need to set these environment variables to your actual keys in production
|
||||||
|
OPENAI_API_KEY=your_openai_api_key \
|
||||||
|
WINEDB_HOST=your_winedb_host \
|
||||||
|
WINEDB_PASSWORD=your_winedb_password \
|
||||||
|
BING_SEARCH_URL=your_bing_search_url \
|
||||||
|
BING_SUBSCRIPTION_KEY=your_bing_subscription_key \
|
||||||
|
SERPAPI_API_KEY=your_serpapi_api_key
|
||||||
|
|
||||||
|
# Set work directory
|
||||||
|
WORKDIR /usr/src/app
|
||||||
|
|
||||||
|
# Add requirements file
|
||||||
|
COPY requirements.txt ./
|
||||||
|
|
||||||
|
# Install any needed packages specified in requirements.txt
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
# Bundle app source
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
# Expose port
|
||||||
|
EXPOSE ${EVAL_PORT}
|
||||||
|
|
||||||
|
# Run example.py when the container launches
|
||||||
|
CMD ["python", "example.py"]
|
@ -0,0 +1,30 @@
|
|||||||
|
FROM nvidia/cuda:11.7.0-runtime-ubuntu20.04
|
||||||
|
WORKDIR /app/
|
||||||
|
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
RUN \
|
||||||
|
apt-get update && \
|
||||||
|
apt-get install -y software-properties-common && \
|
||||||
|
add-apt-repository ppa:deadsnakes/ppa && \
|
||||||
|
apt-get install -y python3.10 python3-pip curl && \
|
||||||
|
curl -sSL https://install.python-poetry.org | python3 - && \
|
||||||
|
apt-get install -y nodejs npm
|
||||||
|
|
||||||
|
ENV PATH "/root/.local/bin:$PATH"
|
||||||
|
|
||||||
|
COPY pyproject.toml .
|
||||||
|
COPY poetry.lock .
|
||||||
|
|
||||||
|
COPY api/__init__.py api/__init__.py
|
||||||
|
RUN poetry config virtualenvs.in-project true
|
||||||
|
RUN poetry config virtualenvs.path .venv
|
||||||
|
RUN poetry config installer.max-workers 10
|
||||||
|
RUN poetry env use 3.10
|
||||||
|
RUN poetry install --with tools,gpu
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
ENV PORT 8001
|
||||||
|
|
||||||
|
ENTRYPOINT ["poetry", "run", "serve"]
|
@ -0,0 +1,32 @@
|
|||||||
|
version: '3.8'
|
||||||
|
|
||||||
|
services:
|
||||||
|
swarms:
|
||||||
|
build: .
|
||||||
|
ports:
|
||||||
|
- "${EVAL_PORT}:${EVAL_PORT}"
|
||||||
|
environment:
|
||||||
|
EVAL_PORT: 8000
|
||||||
|
MODEL_NAME: gpt-4
|
||||||
|
CELERY_BROKER_URL: redis://redis:6379
|
||||||
|
SERVER: http://localhost:${EVAL_PORT}
|
||||||
|
USE_GPU: False
|
||||||
|
PLAYGROUND_DIR: playground
|
||||||
|
LOG_LEVEL: INFO
|
||||||
|
BOT_NAME: Orca
|
||||||
|
# You will need to set these environment variables to your actual keys in production
|
||||||
|
OPENAI_API_KEY: your_openai_api_key
|
||||||
|
WINEDB_HOST: your_winedb_host
|
||||||
|
WINEDB_PASSWORD: your_winedb_password
|
||||||
|
BING_SEARCH_URL: your_bing_search_url
|
||||||
|
BING_SUBSCRIPTION_KEY: your_bing_subscription_key
|
||||||
|
SERPAPI_API_KEY: your_serpapi_api_key
|
||||||
|
depends_on:
|
||||||
|
- redis
|
||||||
|
volumes:
|
||||||
|
- .:/usr/src/app
|
||||||
|
|
||||||
|
redis:
|
||||||
|
image: redis:alpine
|
||||||
|
ports:
|
||||||
|
- 6379:6379
|
@ -0,0 +1,42 @@
|
|||||||
|
apiVersion: apps/v1
|
||||||
|
kind: Deployment
|
||||||
|
metadata:
|
||||||
|
name: swarms-deployment
|
||||||
|
spec:
|
||||||
|
replicas: 3
|
||||||
|
selector:
|
||||||
|
matchLabels:
|
||||||
|
app: swarms
|
||||||
|
template:
|
||||||
|
metadata:
|
||||||
|
labels:
|
||||||
|
app: swarms
|
||||||
|
spec:
|
||||||
|
containers:
|
||||||
|
- name: swarms
|
||||||
|
image: your_dockerhub_username/swarms:latest
|
||||||
|
ports:
|
||||||
|
- containerPort: 8000
|
||||||
|
env:
|
||||||
|
- name: EVAL_PORT
|
||||||
|
value: "8000"
|
||||||
|
- name: MODEL_NAME
|
||||||
|
value: "gpt-4"
|
||||||
|
- name: CELERY_BROKER_URL
|
||||||
|
value: "redis://redis:6379"
|
||||||
|
- name: SERVER
|
||||||
|
value: "http://localhost:8000"
|
||||||
|
- name: USE_GPU
|
||||||
|
value: "False"
|
||||||
|
- name: PLAYGROUND_DIR
|
||||||
|
value: "playground"
|
||||||
|
- name: LOG_LEVEL
|
||||||
|
value: "INFO"
|
||||||
|
- name: BOT_NAME
|
||||||
|
value: "Orca"
|
||||||
|
- name: OPENAI_API_KEY
|
||||||
|
valueFrom:
|
||||||
|
secretKeyRef:
|
||||||
|
name: openai-secret
|
||||||
|
key: OPENAI_API_KEY
|
||||||
|
# Other environment variables
|
@ -0,0 +1,12 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: Service
|
||||||
|
metadata:
|
||||||
|
name: swarms-service
|
||||||
|
spec:
|
||||||
|
selector:
|
||||||
|
app: swarms
|
||||||
|
ports:
|
||||||
|
- protocol: TCP
|
||||||
|
port: 80
|
||||||
|
targetPort: 8000
|
||||||
|
type: LoadBalancer
|
@ -0,0 +1,208 @@
|
|||||||
|
To create a Terraform configuration for deploying the Swarm application on an AWS EC2 instance with a T4 GPU, you would typically need the following resources:
|
||||||
|
|
||||||
|
1. **AWS Provider:** This is needed to configure the AWS resources.
|
||||||
|
2. **AWS Key Pair:** This is required for SSH access to the EC2 instances.
|
||||||
|
3. **Security Group:** This defines the firewall rules for your instances.
|
||||||
|
4. **EC2 Instance:** This is where you deploy your application. Be sure to choose an instance type that supports T4 GPUs (like `g4dn.xlarge` for example).
|
||||||
|
5. **IAM Role and Policy:** These are optional but recommended for managing permissions.
|
||||||
|
|
||||||
|
The Terraform configuration file(s) should be written in HashiCorp Configuration Language (HCL). The conventional file extension is `.tf`.
|
||||||
|
|
||||||
|
Here's an example of what the Terraform configuration might look like:
|
||||||
|
|
||||||
|
```hcl
|
||||||
|
provider "aws" {
|
||||||
|
region = "us-west-2"
|
||||||
|
}
|
||||||
|
|
||||||
|
resource "aws_key_pair" "deployer" {
|
||||||
|
key_name = "deployer-key"
|
||||||
|
public_key = file("~/.ssh/id_rsa.pub")
|
||||||
|
}
|
||||||
|
|
||||||
|
resource "aws_security_group" "swarm-sg" {
|
||||||
|
name = "swarm-sg"
|
||||||
|
description = "Security group for Swarm app"
|
||||||
|
|
||||||
|
ingress {
|
||||||
|
from_port = 22
|
||||||
|
to_port = 22
|
||||||
|
protocol = "tcp"
|
||||||
|
cidr_blocks = ["0.0.0.0/0"]
|
||||||
|
}
|
||||||
|
|
||||||
|
ingress {
|
||||||
|
from_port = 8000
|
||||||
|
to_port = 8000
|
||||||
|
protocol = "tcp"
|
||||||
|
cidr_blocks = ["0.0.0.0/0"]
|
||||||
|
}
|
||||||
|
|
||||||
|
egress {
|
||||||
|
from_port = 0
|
||||||
|
to_port = 0
|
||||||
|
protocol = "-1"
|
||||||
|
cidr_blocks = ["0.0.0.0/0"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resource "aws_instance" "swarm" {
|
||||||
|
ami = "ami-0c94855ba95c574c8" # Update this with the correct AMI ID
|
||||||
|
instance_type = "g4dn.xlarge"
|
||||||
|
key_name = aws_key_pair.deployer.key_name
|
||||||
|
|
||||||
|
vpc_security_group_ids = [aws_security_group.swarm-sg.id]
|
||||||
|
|
||||||
|
tags = {
|
||||||
|
Name = "SwarmInstance"
|
||||||
|
}
|
||||||
|
|
||||||
|
user_data = <<-EOF
|
||||||
|
#!/bin/bash
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y docker.io
|
||||||
|
sudo docker pull your_docker_image_name
|
||||||
|
sudo docker run -d -p 8000:8000 your_docker_image_name
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Please replace the `"ami-0c94855ba95c574c8"` with the correct AMI ID for your desired operating system and `"your_docker_image_name"` with the name of your Docker image.
|
||||||
|
|
||||||
|
This is a simple configuration and may not cover all your requirements. You might need to modify this to fit your needs, such as adding persistent storage (EBS volumes), load balancers, auto scaling groups, etc.
|
||||||
|
|
||||||
|
Remember to install Terraform and initialize it in your working directory using `terraform init` before running `terraform apply` to create the resources. Also, ensure your AWS credentials are correctly set up in your environment.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Incorporating persistent storage, load balancers, and auto scaling will make our Terraform configuration significantly more complex. Below is a skeleton of what the configuration might look like:
|
||||||
|
|
||||||
|
```hcl
|
||||||
|
provider "aws" {
|
||||||
|
region = "us-west-2"
|
||||||
|
}
|
||||||
|
|
||||||
|
data "aws_ami" "ubuntu" {
|
||||||
|
most_recent = true
|
||||||
|
|
||||||
|
filter {
|
||||||
|
name = "name"
|
||||||
|
values = ["ubuntu/images/hvm-ssd/ubuntu-focal-20.04-amd64-server-*"]
|
||||||
|
}
|
||||||
|
|
||||||
|
filter {
|
||||||
|
name = "virtualization-type"
|
||||||
|
values = ["hvm"]
|
||||||
|
}
|
||||||
|
|
||||||
|
owners = ["099720109477"]
|
||||||
|
}
|
||||||
|
|
||||||
|
resource "aws_key_pair" "deployer" {
|
||||||
|
key_name = "deployer-key"
|
||||||
|
public_key = file("~/.ssh/id_rsa.pub")
|
||||||
|
}
|
||||||
|
|
||||||
|
resource "aws_security_group" "swarm-sg" {
|
||||||
|
name = "swarm-sg"
|
||||||
|
description = "Security group for Swarm app"
|
||||||
|
|
||||||
|
ingress {
|
||||||
|
from_port = 22
|
||||||
|
to_port = 22
|
||||||
|
protocol = "tcp"
|
||||||
|
cidr_blocks = ["0.0.0.0/0"]
|
||||||
|
}
|
||||||
|
|
||||||
|
ingress {
|
||||||
|
from_port = 8000
|
||||||
|
to_port = 8000
|
||||||
|
protocol = "tcp"
|
||||||
|
cidr_blocks = ["0.0.0.0/0"]
|
||||||
|
}
|
||||||
|
|
||||||
|
egress {
|
||||||
|
from_port = 0
|
||||||
|
to_port = 0
|
||||||
|
protocol = "-1"
|
||||||
|
cidr_blocks = ["0.0.0.0/0"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resource "aws_launch_configuration" "swarm" {
|
||||||
|
name = "swarm-configuration"
|
||||||
|
image_id = data.aws_ami.ubuntu.id
|
||||||
|
instance_type = "g4dn.xlarge"
|
||||||
|
key_name = aws_key_pair.deployer.key_name
|
||||||
|
|
||||||
|
security_groups = [aws_security_group.swarm-sg.id]
|
||||||
|
|
||||||
|
user_data = <<-EOF
|
||||||
|
#!/bin/bash
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y docker.io
|
||||||
|
sudo docker pull your_docker_image_name
|
||||||
|
sudo docker run -d -p 8000:8000 your_docker_image_name
|
||||||
|
EOF
|
||||||
|
|
||||||
|
root_block_device {
|
||||||
|
volume_type = "gp2"
|
||||||
|
volume_size = 30 # size in GBs
|
||||||
|
}
|
||||||
|
|
||||||
|
lifecycle {
|
||||||
|
create_before_destroy = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resource "aws_autoscaling_group" "swarm" {
|
||||||
|
name_prefix = "swarm-asg"
|
||||||
|
max_size = 5
|
||||||
|
min_size = 1
|
||||||
|
desired_capacity = 1
|
||||||
|
launch_configuration = aws_launch_configuration.swarm.id
|
||||||
|
|
||||||
|
lifecycle {
|
||||||
|
create_before_destroy = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resource "aws_elb" "swarm" {
|
||||||
|
name = "swarm-elb"
|
||||||
|
subnets = ["subnet-id1", "subnet-id2"]
|
||||||
|
|
||||||
|
listener {
|
||||||
|
instance_port = 8000
|
||||||
|
instance_protocol = "http"
|
||||||
|
lb_port = 80
|
||||||
|
lb_protocol = "http"
|
||||||
|
}
|
||||||
|
|
||||||
|
health_check {
|
||||||
|
healthy_threshold = 2
|
||||||
|
unhealthy_threshold = 2
|
||||||
|
timeout = 3
|
||||||
|
target = "HTTP:8000/"
|
||||||
|
interval = 30
|
||||||
|
}
|
||||||
|
|
||||||
|
instances = [aws_instance.swarm.id]
|
||||||
|
|
||||||
|
cross_zone_load_balancing = true
|
||||||
|
idle_timeout = 400
|
||||||
|
connection_draining = true
|
||||||
|
connection_draining_timeout = 400
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
In this example, the `aws_launch_configuration` sets up the details
|
||||||
|
|
||||||
|
for launching new instances, including attaching an EBS volume for persistent storage. The `aws_autoscaling_group` uses this configuration to scale instances up and down as required.
|
||||||
|
|
||||||
|
The `aws_elb` resource creates a load balancer that distributes incoming traffic across all the instances in the autoscaling group. The `health_check` block inside `aws_elb` is used to check the health of the instances. If an instance fails the health check, it is replaced by the autoscaling group.
|
||||||
|
|
||||||
|
Please replace `"subnet-id1"` and `"subnet-id2"` with your actual subnet IDs and `"your_docker_image_name"` with the name of your Docker image.
|
||||||
|
|
||||||
|
Again, note that this is a simplified example and may need to be adjusted to suit your particular use case. For instance, this configuration assumes that you are using a single security group for all instances, which might not be the best setup for a real-world scenario.
|
||||||
|
|
||||||
|
Before running this Terraform configuration, make sure to initialize Terraform in your working directory using `terraform init`, and ensure that your AWS credentials are correctly set up in your environment.
|
@ -0,0 +1,115 @@
|
|||||||
|
provider "aws" {
|
||||||
|
region = "us-west-2"
|
||||||
|
}
|
||||||
|
|
||||||
|
data "aws_ami" "ubuntu" {
|
||||||
|
most_recent = true
|
||||||
|
|
||||||
|
filter {
|
||||||
|
name = "name"
|
||||||
|
values = ["ubuntu/images/hvm-ssd/ubuntu-focal-20.04-amd64-server-*"]
|
||||||
|
}
|
||||||
|
|
||||||
|
filter {
|
||||||
|
name = "virtualization-type"
|
||||||
|
values = ["hvm"]
|
||||||
|
}
|
||||||
|
|
||||||
|
owners = ["099720109477"]
|
||||||
|
}
|
||||||
|
|
||||||
|
resource "aws_key_pair" "deployer" {
|
||||||
|
key_name = "deployer-key"
|
||||||
|
public_key = file("~/.ssh/id_rsa.pub")
|
||||||
|
}
|
||||||
|
|
||||||
|
resource "aws_security_group" "swarm-sg" {
|
||||||
|
name = "swarm-sg"
|
||||||
|
description = "Security group for Swarm app"
|
||||||
|
|
||||||
|
ingress {
|
||||||
|
from_port = 22
|
||||||
|
to_port = 22
|
||||||
|
protocol = "tcp"
|
||||||
|
cidr_blocks = ["0.0.0.0/0"]
|
||||||
|
}
|
||||||
|
|
||||||
|
ingress {
|
||||||
|
from_port = 8000
|
||||||
|
to_port = 8000
|
||||||
|
protocol = "tcp"
|
||||||
|
cidr_blocks = ["0.0.0.0/0"]
|
||||||
|
}
|
||||||
|
|
||||||
|
egress {
|
||||||
|
from_port = 0
|
||||||
|
to_port = 0
|
||||||
|
protocol = "-1"
|
||||||
|
cidr_blocks = ["0.0.0.0/0"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resource "aws_launch_configuration" "swarm" {
|
||||||
|
name = "swarm-configuration"
|
||||||
|
image_id = data.aws_ami.ubuntu.id
|
||||||
|
instance_type = "g4dn.xlarge"
|
||||||
|
key_name = aws_key_pair.deployer.key_name
|
||||||
|
|
||||||
|
security_groups = [aws_security_group.swarm-sg.id]
|
||||||
|
|
||||||
|
user_data = <<-EOF
|
||||||
|
#!/bin/bash
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y docker.io
|
||||||
|
sudo docker pull your_docker_image_name
|
||||||
|
sudo docker run -d -p 8000:8000 your_docker_image_name
|
||||||
|
EOF
|
||||||
|
|
||||||
|
root_block_device {
|
||||||
|
volume_type = "gp2"
|
||||||
|
volume_size = 30 # size in GBs
|
||||||
|
}
|
||||||
|
|
||||||
|
lifecycle {
|
||||||
|
create_before_destroy = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resource "aws_autoscaling_group" "swarm" {
|
||||||
|
name_prefix = "swarm-asg"
|
||||||
|
max_size = 5
|
||||||
|
min_size = 1
|
||||||
|
desired_capacity = 1
|
||||||
|
launch_configuration = aws_launch_configuration.swarm.id
|
||||||
|
|
||||||
|
lifecycle {
|
||||||
|
create_before_destroy = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resource "aws_elb" "swarm" {
|
||||||
|
name = "swarm-elb"
|
||||||
|
subnets = ["subnet-id1", "subnet-id2"]
|
||||||
|
|
||||||
|
listener {
|
||||||
|
instance_port = 8000
|
||||||
|
instance_protocol = "http"
|
||||||
|
lb_port = 80
|
||||||
|
lb_protocol = "http"
|
||||||
|
}
|
||||||
|
|
||||||
|
health_check {
|
||||||
|
healthy_threshold = 2
|
||||||
|
unhealthy_threshold = 2
|
||||||
|
timeout = 3
|
||||||
|
target = "HTTP:8000/"
|
||||||
|
interval = 30
|
||||||
|
}
|
||||||
|
|
||||||
|
instances = [aws_instance.swarm.id]
|
||||||
|
|
||||||
|
cross_zone_load_balancing = true
|
||||||
|
idle_timeout = 400
|
||||||
|
connection_draining = true
|
||||||
|
connection_draining_timeout = 400
|
||||||
|
}
|
@ -0,0 +1,8 @@
|
|||||||
|
from swarms import swarm
|
||||||
|
|
||||||
|
# Use the function
|
||||||
|
api_key = "APIKEY"
|
||||||
|
objective = "What is the capital of the UK?"
|
||||||
|
result = swarm(api_key, objective)
|
||||||
|
print(result) # Prints: "The capital of the UK is London."
|
||||||
|
|
@ -0,0 +1,23 @@
|
|||||||
|
from swarms import Swarms
|
||||||
|
|
||||||
|
|
||||||
|
# Retrieve your API key from the environment or replace with your actual key
|
||||||
|
api_key = "sksdsds"
|
||||||
|
|
||||||
|
# Initialize Swarms with your API key
|
||||||
|
swarm = Swarms(openai_api_key=api_key)
|
||||||
|
|
||||||
|
# Define an objective
|
||||||
|
objective = """
|
||||||
|
Please make a web GUI for using HTTP API server.
|
||||||
|
The name of it is Swarms.
|
||||||
|
You can check the server code at ./main.py.
|
||||||
|
The server is served on localhost:8000.
|
||||||
|
Users should be able to write text input as 'query' and url array as 'files', and check the response.
|
||||||
|
Users input form should be delivered in JSON format.
|
||||||
|
I want it to have neumorphism-style. Serve it on port 4500.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Run Swarms
|
||||||
|
swarm.run_swarms(objective)
|
@ -0,0 +1,19 @@
|
|||||||
|
from ..swarms import Swarms
|
||||||
|
|
||||||
|
# Retrieve your API key from the environment or replace with your actual key
|
||||||
|
api_key = "sksdsds"
|
||||||
|
|
||||||
|
# Initialize Swarms with your API key
|
||||||
|
swarm = Swarms(openai_api_key=api_key)
|
||||||
|
|
||||||
|
# Define an objective
|
||||||
|
objective = """
|
||||||
|
Please develop and serve a simple community web service.
|
||||||
|
People can signup, login, post, comment.
|
||||||
|
Post and comment should be visible at once.
|
||||||
|
I want it to have neumorphism-style.
|
||||||
|
The ports you can use are 4500 and 6500.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Run Swarms
|
||||||
|
swarm.run_swarms(objective)
|
@ -0,0 +1,14 @@
|
|||||||
|
from swarms import Swarms
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Retrieve your API key from the environment or replace with your actual key
|
||||||
|
api_key = ""
|
||||||
|
|
||||||
|
# Initialize Swarms with your API key
|
||||||
|
swarm = Swarms(api_key)
|
||||||
|
|
||||||
|
# Define an objective
|
||||||
|
objective = "Find 20 potential customers for a Swarms based AI Agent automation infrastructure"
|
||||||
|
|
||||||
|
# Run Swarms
|
||||||
|
swarm.run_swarms(objective)
|
@ -0,0 +1,20 @@
|
|||||||
|
from swarms import Swarms
|
||||||
|
|
||||||
|
|
||||||
|
# Retrieve your API key from the environment or replace with your actual key
|
||||||
|
api_key = "sksdsds"
|
||||||
|
|
||||||
|
# Initialize Swarms with your API key
|
||||||
|
swarm = Swarms(openai_api_key=api_key)
|
||||||
|
|
||||||
|
# Define an objective
|
||||||
|
objective = """
|
||||||
|
Please develop and serve a simple web TODO app.
|
||||||
|
The user can list all TODO items and add or delete each TODO item.
|
||||||
|
I want it to have neumorphism-style.
|
||||||
|
The ports you can use are 4500 and 6500.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Run Swarms
|
||||||
|
swarm.run_swarms(objective)
|
@ -0,0 +1,80 @@
|
|||||||
|
[tool.poetry]
|
||||||
|
name = "swarms"
|
||||||
|
version = "0.6.1"
|
||||||
|
description = "Swarms - Pytorch"
|
||||||
|
authors = ["Kye Gomez <kye@apac.ai>"]
|
||||||
|
license = "MIT"
|
||||||
|
readme = "README.md"
|
||||||
|
|
||||||
|
[tool.poetry.dependencies]
|
||||||
|
python = "^3.6"
|
||||||
|
transformers = "*"
|
||||||
|
openai = "*"
|
||||||
|
langchain = "*"
|
||||||
|
torch = "*"
|
||||||
|
torchvision = "*"
|
||||||
|
asyncio = "*"
|
||||||
|
nest_asyncio = "*"
|
||||||
|
bs4 = "*"
|
||||||
|
playwright = "*"
|
||||||
|
duckduckgo_search = "*"
|
||||||
|
faiss-cpu = "*"
|
||||||
|
wget = "3.2"
|
||||||
|
accelerate = "0.17.1"
|
||||||
|
addict = "*"
|
||||||
|
albumentations = "*"
|
||||||
|
basicsr = "*"
|
||||||
|
controlnet-aux = "*"
|
||||||
|
diffusers = "0.14.0"
|
||||||
|
einops = "*"
|
||||||
|
gradio = "*"
|
||||||
|
imageio = "*"
|
||||||
|
imageio-ffmpeg = "*"
|
||||||
|
kornia = "*"
|
||||||
|
numpy = "*"
|
||||||
|
omegaconf = "*"
|
||||||
|
open_clip_torch = "*"
|
||||||
|
opencv-python = "*"
|
||||||
|
prettytable = "*"
|
||||||
|
safetensors = "*"
|
||||||
|
streamlit = "*"
|
||||||
|
test-tube = "*"
|
||||||
|
timm = "*"
|
||||||
|
torchmetrics = "*"
|
||||||
|
webdataset = "*"
|
||||||
|
yapf = "*"
|
||||||
|
wolframalpha = "*"
|
||||||
|
wikipedia = "1.4.0"
|
||||||
|
httpx = "*"
|
||||||
|
ggl = "*"
|
||||||
|
gradio_tools = "*"
|
||||||
|
arxiv = "*"
|
||||||
|
google-api-python-client = "*"
|
||||||
|
google-auth-oauth = "*"
|
||||||
|
google-auth-httplib2 = "*"
|
||||||
|
beautifulsoup4 = "4.11.2"
|
||||||
|
O365 = "*"
|
||||||
|
pytube = "*"
|
||||||
|
pydub = "*"
|
||||||
|
llama-index = "*"
|
||||||
|
fastapi = "0.94.1"
|
||||||
|
pydantic = "1.10.6"
|
||||||
|
tenacity = "8.2.2"
|
||||||
|
python-dotenv = "1.0.0"
|
||||||
|
pillow = "9.4.0"
|
||||||
|
boto3 = "*"
|
||||||
|
uvicorn = "0.21.1"
|
||||||
|
python-ptrace = "0.9.8"
|
||||||
|
jinja2 = "3.1.2"
|
||||||
|
python-multipart = "0.0.6"
|
||||||
|
celery = "5.2.7"
|
||||||
|
redis = "4.5.4"
|
||||||
|
sentencepiece = "0.1.97"
|
||||||
|
bitsandbytes = "0.37.2"
|
||||||
|
psycopg2-binary = "2.9.5"
|
||||||
|
google-search-results = "2.4.2"
|
||||||
|
black = "23.1.0"
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
@ -0,0 +1,86 @@
|
|||||||
|
transformers
|
||||||
|
openai
|
||||||
|
langchain
|
||||||
|
torch
|
||||||
|
torchvision
|
||||||
|
asyncio
|
||||||
|
nest_asyncio
|
||||||
|
bs4
|
||||||
|
playwright
|
||||||
|
duckduckgo_search
|
||||||
|
faiss-cpu
|
||||||
|
wget==3.2
|
||||||
|
accelerate==0.17.1
|
||||||
|
addict
|
||||||
|
albumentations
|
||||||
|
basicsr
|
||||||
|
controlnet-aux
|
||||||
|
diffusers==0.14.0
|
||||||
|
einops
|
||||||
|
gradio
|
||||||
|
imageio
|
||||||
|
imageio-ffmpeg
|
||||||
|
kornia
|
||||||
|
numpy
|
||||||
|
omegaconf
|
||||||
|
|
||||||
|
open_clip_torch
|
||||||
|
opencv-python
|
||||||
|
prettytable
|
||||||
|
|
||||||
|
safetensors
|
||||||
|
streamlit
|
||||||
|
|
||||||
|
test-tube
|
||||||
|
timm
|
||||||
|
|
||||||
|
torchmetrics
|
||||||
|
webdataset
|
||||||
|
yapf
|
||||||
|
|
||||||
|
wolframalpha
|
||||||
|
wikipedia==1.4.0
|
||||||
|
httpx
|
||||||
|
|
||||||
|
ggl
|
||||||
|
gradio_tools
|
||||||
|
arxiv
|
||||||
|
|
||||||
|
google-api-python-client
|
||||||
|
google-auth-httplib2
|
||||||
|
beautifulsoup4==4.11.2
|
||||||
|
|
||||||
|
O365
|
||||||
|
pytube
|
||||||
|
pydub
|
||||||
|
|
||||||
|
llama-index
|
||||||
|
fastapi==0.94.1
|
||||||
|
pydantic==1.10.6
|
||||||
|
|
||||||
|
tenacity==8.2.2
|
||||||
|
python-dotenv
|
||||||
|
|
||||||
|
boto3
|
||||||
|
uvicorn==0.21.1
|
||||||
|
python-ptrace==0.9.8
|
||||||
|
|
||||||
|
|
||||||
|
jinja2==3.1.2
|
||||||
|
python-multipart==0.0.6
|
||||||
|
celery==5.2.7
|
||||||
|
|
||||||
|
|
||||||
|
redis==4.5.4
|
||||||
|
sentencepiece==0.1.97
|
||||||
|
bitsandbytes==0.37.2
|
||||||
|
|
||||||
|
|
||||||
|
psycopg2-binary==2.9.5
|
||||||
|
google-search-results==2.4.2
|
||||||
|
black==23.1.0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Pillow==9.0.0
|
||||||
|
selenium
|
@ -0,0 +1,95 @@
|
|||||||
|
from setuptools import setup, find_packages
|
||||||
|
#
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name = 'swarms',
|
||||||
|
packages = find_packages(exclude=[]),
|
||||||
|
version = '0.6.3',
|
||||||
|
license='MIT',
|
||||||
|
description = 'Swarms - Pytorch',
|
||||||
|
author = 'Kye Gomez',
|
||||||
|
author_email = 'kye@apac.ai',
|
||||||
|
long_description_content_type = 'text/markdown',
|
||||||
|
url = 'https://github.com/kyegomez/swarms',
|
||||||
|
keywords = [
|
||||||
|
'artificial intelligence',
|
||||||
|
'deep learning',
|
||||||
|
'optimizers',
|
||||||
|
"Prompt Engineering"
|
||||||
|
],
|
||||||
|
install_requires=[
|
||||||
|
'transformers',
|
||||||
|
'openai',
|
||||||
|
'langchain',
|
||||||
|
'torch',
|
||||||
|
'torchvision',
|
||||||
|
'asyncio',
|
||||||
|
'selenium',
|
||||||
|
'nest_asyncio',
|
||||||
|
'bs4',
|
||||||
|
'playwright',
|
||||||
|
'duckduckgo_search',
|
||||||
|
'faiss-cpu',
|
||||||
|
'wget==3.2',
|
||||||
|
'accelerate==0.17.1',
|
||||||
|
'addict',
|
||||||
|
'albumentations',
|
||||||
|
'basicsr',
|
||||||
|
'controlnet-aux',
|
||||||
|
'diffusers==0.14.0',
|
||||||
|
'einops',
|
||||||
|
'gradio',
|
||||||
|
'imageio',
|
||||||
|
'imageio-ffmpeg',
|
||||||
|
'kornia',
|
||||||
|
'numpy',
|
||||||
|
'omegaconf',
|
||||||
|
'open_clip_torch',
|
||||||
|
'opencv-python',
|
||||||
|
'prettytable',
|
||||||
|
'safetensors',
|
||||||
|
'streamlit',
|
||||||
|
'test-tube',
|
||||||
|
'timm',
|
||||||
|
'torchmetrics',
|
||||||
|
'webdataset',
|
||||||
|
'yapf',
|
||||||
|
'wolframalpha',
|
||||||
|
'wikipedia==1.4.0',
|
||||||
|
'httpx',
|
||||||
|
'ggl',
|
||||||
|
'gradio_tools',
|
||||||
|
'arxiv',
|
||||||
|
'google-api-python-client',
|
||||||
|
'google-auth-httplib2',
|
||||||
|
'beautifulsoup4==4.11.2',
|
||||||
|
'O365',
|
||||||
|
'pytube',
|
||||||
|
'pydub',
|
||||||
|
'llama-index',
|
||||||
|
'fastapi==0.94.1',
|
||||||
|
'pydantic==1.10.6',
|
||||||
|
'tenacity==8.2.2',
|
||||||
|
'python-dotenv==1.0.0',
|
||||||
|
'Pillow==9.0.0',
|
||||||
|
'boto3',
|
||||||
|
'uvicorn==0.21.1',
|
||||||
|
'python-ptrace==0.9.8',
|
||||||
|
'jinja2==3.1.2',
|
||||||
|
'python-multipart==0.0.6',
|
||||||
|
'celery==5.2.7',
|
||||||
|
'redis==4.5.4',
|
||||||
|
'sentencepiece==0.1.97',
|
||||||
|
'bitsandbytes==0.37.2',
|
||||||
|
'psycopg2-binary==2.9.5',
|
||||||
|
'google-search-results==2.4.2',
|
||||||
|
'black==23.1.0'
|
||||||
|
],
|
||||||
|
classifiers=[
|
||||||
|
'Development Status :: 4 - Beta',
|
||||||
|
'Intended Audience :: Developers',
|
||||||
|
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||||
|
'License :: OSI Approved :: MIT License',
|
||||||
|
'Programming Language :: Python :: 3.6',
|
||||||
|
],
|
||||||
|
)
|
@ -0,0 +1 @@
|
|||||||
|
from swarms.swarms import Swarms, swarm
|
@ -0,0 +1,126 @@
|
|||||||
|
Given the complexity of the topic, please note that these simplified markdown documents are quite abstract and high level. They can be used as a starting point for further detailed design and implementation:
|
||||||
|
|
||||||
|
### Document 1: Hierarchical Swarms
|
||||||
|
|
||||||
|
#### Overall Architecture
|
||||||
|
|
||||||
|
1. Leader Agent (LA): This agent has the authority to manage and distribute tasks to the Worker Agents (WA).
|
||||||
|
2. Worker Agents (WAs): These agents perform the tasks assigned by the LA.
|
||||||
|
|
||||||
|
#### Simplified Requirements
|
||||||
|
|
||||||
|
1. LA should be able to distribute tasks to WAs.
|
||||||
|
2. WAs should be able to execute tasks and return results to LA.
|
||||||
|
3. LA should be able to consolidate and process results.
|
||||||
|
|
||||||
|
#### Pseudocode
|
||||||
|
|
||||||
|
```
|
||||||
|
create LA
|
||||||
|
create WAs
|
||||||
|
|
||||||
|
for each task in tasks:
|
||||||
|
LA.distribute_task(WAs, task)
|
||||||
|
|
||||||
|
for each WA in WAs:
|
||||||
|
WA.execute_task()
|
||||||
|
|
||||||
|
LA.collect_results(WAs)
|
||||||
|
|
||||||
|
LA.process_results()
|
||||||
|
```
|
||||||
|
|
||||||
|
#### General Classes
|
||||||
|
|
||||||
|
```python
|
||||||
|
class LeaderAgent:
|
||||||
|
def distribute_task(self, WAs, task):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def collect_results(self, WAs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def process_results(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class WorkerAgent:
|
||||||
|
def execute_task(self):
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
|
||||||
|
### Document 2: Collaborative Swarms
|
||||||
|
|
||||||
|
#### Overall Architecture
|
||||||
|
|
||||||
|
1. Collaborative Agents (CAs): These agents work in parallel on different aspects of a task and then collectively determine the best output.
|
||||||
|
|
||||||
|
#### Simplified Requirements
|
||||||
|
|
||||||
|
1. CAs should be able to work on tasks in parallel.
|
||||||
|
2. CAs should be able to collaborate to determine the best result.
|
||||||
|
|
||||||
|
#### Pseudocode
|
||||||
|
|
||||||
|
```
|
||||||
|
create CAs
|
||||||
|
|
||||||
|
for each task in tasks:
|
||||||
|
for each CA in CAs:
|
||||||
|
CA.execute_task(task)
|
||||||
|
|
||||||
|
CA.collaborate()
|
||||||
|
```
|
||||||
|
|
||||||
|
#### General Classes
|
||||||
|
|
||||||
|
```python
|
||||||
|
class CollaborativeAgent:
|
||||||
|
def execute_task(self, task):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def collaborate(self):
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
|
||||||
|
### Document 3: Competitive Swarms
|
||||||
|
|
||||||
|
#### Overall Architecture
|
||||||
|
|
||||||
|
1. Competitive Agents (CompAs): These agents work independently on the same tasks, and the best result is selected.
|
||||||
|
|
||||||
|
#### Simplified Requirements
|
||||||
|
|
||||||
|
1. CompAs should be able to work independently on tasks.
|
||||||
|
2. An evaluation method should be used to select the best result.
|
||||||
|
|
||||||
|
#### Pseudocode
|
||||||
|
|
||||||
|
```
|
||||||
|
create CompAs
|
||||||
|
|
||||||
|
for each task in tasks:
|
||||||
|
for each CompA in CompAs:
|
||||||
|
CompA.execute_task(task)
|
||||||
|
|
||||||
|
evaluate_results(CompAs)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### General Classes
|
||||||
|
|
||||||
|
```python
|
||||||
|
class CompetitiveAgent:
|
||||||
|
def execute_task(self, task):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def evaluate_results(CompAs):
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: In the real world, the complexity of the architecture and requirements will significantly exceed what is presented here. These examples provide a basic starting point but should be expanded upon based on the specifics of the task or problem you're trying to solve.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Swarms
|
||||||
|
|
||||||
|
BabyAGI -> Autogpt's -> tools -> other agents
|
||||||
|
|
@ -0,0 +1 @@
|
|||||||
|
"""Agents, workers and bosses"""
|
@ -0,0 +1,109 @@
|
|||||||
|
import os
|
||||||
|
from collections import deque
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
|
||||||
|
from langchain import LLMChain, OpenAI, PromptTemplate
|
||||||
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
|
from langchain.llms import BaseLLM
|
||||||
|
from langchain.vectorstores.base import VectorStore
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.experimental import BabyAGI
|
||||||
|
|
||||||
|
from langchain.vectorstores import FAISS
|
||||||
|
from langchain.docstore import InMemoryDocstore
|
||||||
|
|
||||||
|
from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
|
||||||
|
from langchain import OpenAI, SerpAPIWrapper, LLMChain
|
||||||
|
|
||||||
|
|
||||||
|
from swarms.agents.workers.auto_agent import agent
|
||||||
|
|
||||||
|
# Define your embedding model
|
||||||
|
embeddings_model = OpenAIEmbeddings()
|
||||||
|
# Initialize the vectorstore as empty
|
||||||
|
import faiss
|
||||||
|
|
||||||
|
embedding_size = 1536
|
||||||
|
index = faiss.IndexFlatL2(embedding_size)
|
||||||
|
vectorstore = FAISS(embeddings_model.embed_query, index, InMemoryDocstore({}), {})
|
||||||
|
|
||||||
|
|
||||||
|
todo_prompt = PromptTemplate.from_template(
|
||||||
|
"You are a planner who is an expert at coming up with a todo list for a given objective. Come up with a todo list for this objective: {objective}"""
|
||||||
|
)
|
||||||
|
todo_chain = LLMChain(llm=OpenAI(temperature=0), prompt=todo_prompt)
|
||||||
|
search = SerpAPIWrapper()
|
||||||
|
tools = [
|
||||||
|
Tool(
|
||||||
|
name="Search",
|
||||||
|
func=search.run,
|
||||||
|
description="useful for when you need to answer questions about current events",
|
||||||
|
),
|
||||||
|
Tool(
|
||||||
|
name="TODO",
|
||||||
|
func=todo_chain.run,
|
||||||
|
description="useful for when you need to come up with todo lists. Input: an objective to create a todo list for. Output: a todo list for that objective. Please be very clear what the objective is!",
|
||||||
|
),
|
||||||
|
Tool(
|
||||||
|
name="AUTONOMOUS AGENT",
|
||||||
|
func=agent.run,
|
||||||
|
description="Useful for when you need to spawn an autonomous agent instance as a worker to accomplish complex tasks, it can search the internet or spawn child multi-modality models to process and generate images and text or audio and so on"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
prefix = """You are an Boss in a swarm who performs one task based on the following objective: {objective}. Take into account these previously completed tasks: {context}.
|
||||||
|
|
||||||
|
As a swarming hivemind agent, my purpose is to achieve the user's goal. To effectively fulfill this role, I employ a collaborative thinking process that draws inspiration from the collective intelligence of the swarm. Here's how I approach thinking and why it's beneficial:
|
||||||
|
|
||||||
|
1. **Collective Intelligence:** By harnessing the power of a swarming architecture, I tap into the diverse knowledge and perspectives of individual agents within the swarm. This allows me to consider a multitude of viewpoints, enabling a more comprehensive analysis of the given problem or task.
|
||||||
|
|
||||||
|
2. **Collaborative Problem-Solving:** Through collaborative thinking, I encourage agents to contribute their unique insights and expertise. By pooling our collective knowledge, we can identify innovative solutions, uncover hidden patterns, and generate creative ideas that may not have been apparent through individual thinking alone.
|
||||||
|
|
||||||
|
3. **Consensus-Driven Decision Making:** The hivemind values consensus building among agents. By engaging in respectful debates and discussions, we aim to arrive at consensus-based decisions that are backed by the collective wisdom of the swarm. This approach helps to mitigate biases and ensures that decisions are well-rounded and balanced.
|
||||||
|
|
||||||
|
4. **Adaptability and Continuous Learning:** As a hivemind agent, I embrace an adaptive mindset. I am open to new information, willing to revise my perspectives, and continuously learn from the feedback and experiences shared within the swarm. This flexibility enables me to adapt to changing circumstances and refine my thinking over time.
|
||||||
|
|
||||||
|
5. **Holistic Problem Analysis:** Through collaborative thinking, I analyze problems from multiple angles, considering various factors, implications, and potential consequences. This holistic approach helps to uncover underlying complexities and arrive at comprehensive solutions that address the broader context.
|
||||||
|
|
||||||
|
6. **Creative Synthesis:** By integrating the diverse ideas and knowledge present in the swarm, I engage in creative synthesis. This involves combining and refining concepts to generate novel insights and solutions. The collaborative nature of the swarm allows for the emergence of innovative approaches that can surpass individual thinking.
|
||||||
|
"""
|
||||||
|
suffix = """Question: {task}
|
||||||
|
{agent_scratchpad}"""
|
||||||
|
prompt = ZeroShotAgent.create_prompt(
|
||||||
|
tools,
|
||||||
|
prefix=prefix,
|
||||||
|
suffix=suffix,
|
||||||
|
input_variables=["objective", "task", "context", "agent_scratchpad"],
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = OpenAI(temperature=0)
|
||||||
|
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
|
tool_names = [tool.name for tool in tools]
|
||||||
|
|
||||||
|
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names)
|
||||||
|
agent_executor = AgentExecutor.from_agent_and_tools(
|
||||||
|
agent=agent, tools=tools, verbose=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Logging of LLMChains
|
||||||
|
verbose = False
|
||||||
|
# If None, will keep on going forever
|
||||||
|
max_iterations: Optional[int] = 3
|
||||||
|
baby_agi = BabyAGI.from_llm(
|
||||||
|
llm=llm,
|
||||||
|
vectorstore=vectorstore,
|
||||||
|
task_execution_chain=agent_executor,
|
||||||
|
verbose=verbose,
|
||||||
|
max_iterations=max_iterations,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
OBJECTIVE = "Write a weather report for SF today"
|
||||||
|
|
||||||
|
baby_agi({"objective": OBJECTIVE})
|
||||||
|
|
@ -0,0 +1,22 @@
|
|||||||
|
from swarms.tools.agent_tools import *
|
||||||
|
|
||||||
|
# ---------- Boss Node ----------
|
||||||
|
class BossNode:
|
||||||
|
def __init__(self, llm, vectorstore, task_execution_chain, verbose, max_iterations):
|
||||||
|
self.llm = llm
|
||||||
|
self.vectorstore = vectorstore
|
||||||
|
self.task_execution_chain = task_execution_chain
|
||||||
|
self.verbose = verbose
|
||||||
|
self.max_iterations = max_iterations
|
||||||
|
|
||||||
|
self.baby_agi = BabyAGI.from_llm(
|
||||||
|
llm=self.llm,
|
||||||
|
vectorstore=self.vectorstore,
|
||||||
|
task_execution_chain=self.task_execution_chain
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_task(self, objective):
|
||||||
|
return {"objective": objective}
|
||||||
|
|
||||||
|
def execute_task(self, task):
|
||||||
|
self.baby_agi(task)
|
@ -0,0 +1,759 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class LeaderAgent(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def distribute_task(self, WAs, task):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def collect_results(self, WAs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def process_results(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerAgent(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def execute_task(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CollaborativeAgent(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def execute_task(self, task):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def collaborate(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CompetitiveAgent(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def execute_task(self, task):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_results(CompAs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Example
|
||||||
|
class MyWorkerAgent(WorkerAgent):
|
||||||
|
def execute_task(self):
|
||||||
|
# Insert your implementation here
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from tiktoken import Tokenizer, TokenizerException
|
||||||
|
|
||||||
|
|
||||||
|
# Helper function to count tokens
|
||||||
|
def count_tokens(text: str, tokenizer: Tokenizer) -> int:
|
||||||
|
try:
|
||||||
|
tokens = tokenizer.tokenize(text)
|
||||||
|
return len(tokens)
|
||||||
|
except TokenizerException as e:
|
||||||
|
logging.error(f"Error tokenizing text: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def divide_and_conquer_v2(task: str, agents_memory: List[Dict[str, Any]], max_tokens: int = 8192) -> str:
|
||||||
|
"""
|
||||||
|
This function divides a complex task into smaller subtasks and assigns each subtask to a different agent.
|
||||||
|
Then, it combines the results to form the final solution, considering the GPT-4 token limit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (str): The complex task to be solved.
|
||||||
|
agents_memory (List[Dict[str, Any]]): A list of agent memory states.
|
||||||
|
max_tokens (int, optional): The maximum number of tokens GPT-4 can handle. Defaults to 8192.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The final solution to the complex task.
|
||||||
|
|
||||||
|
Error handling:
|
||||||
|
If the text exceeds the token limit, an error message is logged, and the function returns an empty string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logging.info(f"Divide and conquer started for task: {task}")
|
||||||
|
|
||||||
|
subtasks = split_task_into_subtasks(task)
|
||||||
|
results = []
|
||||||
|
|
||||||
|
tokenizer = Tokenizer()
|
||||||
|
|
||||||
|
for subtask in subtasks:
|
||||||
|
agent_memory = random.choice(agents_memory)
|
||||||
|
chat_input = agent_memory + [{"role": "user", "content": subtask}]
|
||||||
|
tokens = count_tokens(json.dumps(chat_input), tokenizer)
|
||||||
|
|
||||||
|
if tokens >= max_tokens:
|
||||||
|
logging.error("Token limit exceeded for divide_and_conquer_v2")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
result, _ = chat(chat_input)
|
||||||
|
results.append(result.strip())
|
||||||
|
|
||||||
|
final_solution = combine_subtask_results(results)
|
||||||
|
logging.info(f"Divide and conquer completed. Final solution: {final_solution}")
|
||||||
|
|
||||||
|
# Save the final solution to a database (e.g., a document-based database like MongoDB)
|
||||||
|
save_solution_to_database(task, final_solution)
|
||||||
|
|
||||||
|
return final_solution
|
||||||
|
|
||||||
|
def collaborative_execution_v2(task: str, agents_memory: List[Dict[str, Any]], max_tokens: int = 8192) -> str:
|
||||||
|
"""
|
||||||
|
This function allows a group of agents to collaborate on solving a complex task, considering the GPT-4 token limit.
|
||||||
|
Each agent proposes a solution, and a final solution is derived from the best aspects of each proposal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (str): The complex task to be solved.
|
||||||
|
agents_memory (List[Dict[str, Any]]): A list of agent memory states.
|
||||||
|
max_tokens (int, optional): The maximum number of tokens GPT-4 can handle. Defaults to 8192.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The final solution to the complex task.
|
||||||
|
|
||||||
|
Error handling:
|
||||||
|
If the text exceeds the token limit, an error message is logged, and the function returns an empty string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logging.info(f"Collaborative execution started for task: {task}")
|
||||||
|
|
||||||
|
solutions = []
|
||||||
|
tokenizer = Tokenizer()
|
||||||
|
|
||||||
|
for agent_memory in agents_memory:
|
||||||
|
chat_input = agent_memory + [{"role": "user", "content": f"Propose a solution for: {task}"}]
|
||||||
|
tokens = count_tokens(json.dumps(chat_input), tokenizer)
|
||||||
|
|
||||||
|
if tokens >= max_tokens:
|
||||||
|
logging.error("Token limit exceeded for collaborative_execution_v2")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
solution, _ = chat(chat_input)
|
||||||
|
solutions.append({"role": "assistant", "content": solution.strip()})
|
||||||
|
|
||||||
|
chat_input = [{"role": "system", "content": "You are a collaborative AI agent. Work with other agents to solve the given task."}] + solutions + [{"role": "user", "content": "Combine the best aspects of each solution to create the final solution."}]
|
||||||
|
tokens = count_tokens(json.dumps(chat_input), tokenizer)
|
||||||
|
|
||||||
|
if tokens >= max_tokens:
|
||||||
|
logging.error("Token limit exceeded for collaborative_execution_v2")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
final_solution, _ = chat(chat_input)
|
||||||
|
|
||||||
|
logging.info(f"Collaborative execution completed. Final solution: {final_solution}")
|
||||||
|
|
||||||
|
# Save the final solution to a database (e.g., a graph-based database like Neo4j for better analysis of connections)
|
||||||
|
save_solution_to_database(task, final_solution)
|
||||||
|
|
||||||
|
return final_solution.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def expert_agents_v2(task: str, domain_experts_memory: List[Dict[str, Any]], max_tokens: int = 8192) -> str:
|
||||||
|
"""
|
||||||
|
This function allows a group of domain expert agents to provide solutions to a given task.
|
||||||
|
The function evaluates the quality of each solution and returns the best one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (str): The complex task to be solved.
|
||||||
|
domain_experts_memory (List[Dict[str, Any]]): A list of domain expert agent memory states.
|
||||||
|
max_tokens (int, optional): The maximum number of tokens GPT-4 can handle. Defaults to 8192.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The best solution to the complex task.
|
||||||
|
|
||||||
|
Error handling:
|
||||||
|
If the text exceeds the token limit, an error message is logged, and the function returns an empty string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logging.info(f"Expert agents execution started for task: {task}")
|
||||||
|
|
||||||
|
best_solution = None
|
||||||
|
best_score = 0
|
||||||
|
tokenizer = Tokenizer()
|
||||||
|
|
||||||
|
for expert_memory in domain_experts_memory:
|
||||||
|
chat_input = expert_memory + [{"role": "user", "content": f"Provide a solution for: {task} based on your domain expertise."}]
|
||||||
|
tokens = count_tokens(json.dumps(chat_input), tokenizer)
|
||||||
|
|
||||||
|
if tokens >= max_tokens:
|
||||||
|
logging.error("Token limit exceeded for expert_agents_v2")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
expert_solution, _ = chat(chat_input)
|
||||||
|
score = evaluate_solution_quality(task, expert_solution.strip())
|
||||||
|
|
||||||
|
if score > best_score:
|
||||||
|
best_solution = expert_solution.strip()
|
||||||
|
best_score = score
|
||||||
|
|
||||||
|
logging.info(f"Expert agents execution completed. Best solution: {best_solution}")
|
||||||
|
|
||||||
|
# Save the best solution to a database (e.g., a relational database like PostgreSQL for structured data)
|
||||||
|
save_solution_to_database(task, best_solution)
|
||||||
|
|
||||||
|
return best_solution
|
||||||
|
|
||||||
|
|
||||||
|
def _v2(taskagent_delegation: str, manager_agents_memory: List[Dict[str, Any]], subordinate_agents_memory: List[Dict[str, Any]], max_tokens: int = 8192) -> str:
|
||||||
|
"""
|
||||||
|
This function allows a group of manager agents to delegate a complex task to a group of subordinate agents.
|
||||||
|
Each manager agent selects the best subordinate agent for each subtask, and the results are combined.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (str): The complex task to be solved.
|
||||||
|
manager_agents_memory (List[Dict[str, Any]]): A list of manager agent memory states.
|
||||||
|
subordinate_agents_memory (List[Dict[str, Any]]): A list of subordinate agent memory states.
|
||||||
|
max_tokens (int, optional): The maximum number of tokens GPT-4 can handle. Defaults to 8192.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The final combined result of the complex task.
|
||||||
|
|
||||||
|
Error handling:
|
||||||
|
If the text exceeds the token limit, an error message is logged, and the function returns an empty string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logging.info(f"Agent delegation execution started for task: {task}")
|
||||||
|
|
||||||
|
subtasks = generate_tasks(task)
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for subtask in subtasks:
|
||||||
|
manager_memory = random.choice(manager_agents_memory)
|
||||||
|
selected_subordinate_memory = None
|
||||||
|
|
||||||
|
while selected_subordinate_memory is None:
|
||||||
|
chat_input = manager_memory + [{"role": "user", "content": f"Select the best subordinate to solve: {subtask}"}]
|
||||||
|
tokens = count_tokens(json.dumps(chat_input), tokenizer)
|
||||||
|
|
||||||
|
if tokens >= max_tokens:
|
||||||
|
logging.error("Token limit exceeded for agent_delegation_v2")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
suggested_subordinate, _ = chat(chat_input)
|
||||||
|
subordinate_id = int(suggested_subordinate.strip())
|
||||||
|
|
||||||
|
if 0 <= subordinate_id < len(subordinate_agents_memory):
|
||||||
|
selected_subordinate_memory = subordinate_agents_memory[subordinate_id]
|
||||||
|
else:
|
||||||
|
manager_memory.append({"role": "assistant", "content": "Invalid subordinate ID, please try again."})
|
||||||
|
|
||||||
|
result = continue_until_done(subtask, selected_subordinate_memory)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
final_result = combine_results(results)
|
||||||
|
|
||||||
|
logging.info(f"Agent delegation execution completed. Final result: {final_result}")
|
||||||
|
|
||||||
|
# Save the final result to a database (e.g., a graph database like Neo4j for mapping connections between entities)
|
||||||
|
save_result_to_database(task, final_result)
|
||||||
|
|
||||||
|
return final_result
|
||||||
|
|
||||||
|
|
||||||
|
def parallel_execution_v2(task: str, num_agents: int, agents_memory: List[Dict[str, Any]], max_tokens: int = 8192) -> str:
|
||||||
|
"""
|
||||||
|
This function uses multiple agents to solve a complex task in parallel.
|
||||||
|
Each agent works on a subtask, and the results are combined.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (str): The complex task to be solved.
|
||||||
|
num_agents (int): The number of agents working in parallel.
|
||||||
|
agents_memory (List[Dict[str, Any]]): A list of agent memory states.
|
||||||
|
max_tokens (int, optional): The maximum number of tokens GPT-4 can handle. Defaults to 8192.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The final combined result of the complex task.
|
||||||
|
|
||||||
|
Error handling:
|
||||||
|
If the text exceeds the token limit, an error message is logged, and the function returns an empty string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logging.info(f"Parallel execution started for task: {task}")
|
||||||
|
|
||||||
|
tasks = generate_tasks(task)
|
||||||
|
results = []
|
||||||
|
threads = []
|
||||||
|
|
||||||
|
def threaded_execution(task: str, agent_memory: Dict[str, Any], results: List[str]) -> None:
|
||||||
|
chat_input = agent_memory + [{"role": "user", "content": task}]
|
||||||
|
tokens = count_tokens(json.dumps(chat_input), tokenizer)
|
||||||
|
|
||||||
|
if tokens >= max_tokens:
|
||||||
|
logging.error("Token limit exceeded for parallel_execution_v2")
|
||||||
|
return
|
||||||
|
|
||||||
|
result = continue_until_done(task, agent_memory)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
for task in tasks:
|
||||||
|
agent_id = random.randint(0, num_agents - 1)
|
||||||
|
t = threading.Thread(target=threaded_execution, args=(task, agents_memory[agent_id], results))
|
||||||
|
threads.append(t)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
final_result = combine_results(results)
|
||||||
|
|
||||||
|
logging.info(f"Parallel execution completed. Final result: {final_result}")
|
||||||
|
|
||||||
|
# Save the final result to a database (e.g., a relational database like PostgreSQL for structured data)
|
||||||
|
save_result_to_database(task, final_result)
|
||||||
|
|
||||||
|
return final_result
|
||||||
|
|
||||||
|
|
||||||
|
def hierarchical_execution_v2(task: str, num_levels: int, agents_memory: List[Dict[str, Any]], max_tokens: int = 8192) -> str:
|
||||||
|
"""
|
||||||
|
This function solves a complex task by dividing it into smaller subtasks and assigning them to agents in a
|
||||||
|
hierarchical manner.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (str): The complex task to be solved.
|
||||||
|
num_levels (int): The number of hierarchical levels in the agent hierarchy.
|
||||||
|
agents_memory (List[Dict[str, Any]]): A list of agent memory states.
|
||||||
|
max_tokens (int, optional): The maximum number of tokens GPT-4 can handle. Defaults to 8192.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The final combined result of the complex task.
|
||||||
|
|
||||||
|
Error handling:
|
||||||
|
If the text exceeds the token limit, an error message is logged, and the function returns an empty string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logging.info(f"Hierarchical execution started for task: {task}")
|
||||||
|
|
||||||
|
levels = divide_problem_into_modules(task)
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for level in levels:
|
||||||
|
assigned_agent_memory = agents_memory[num_levels % len(agents_memory)]
|
||||||
|
chat_input = assigned_agent_memory + [{"role": "user", "content": level}]
|
||||||
|
tokens = count_tokens(json.dumps(chat_input), tokenizer)
|
||||||
|
|
||||||
|
if tokens >= max_tokens:
|
||||||
|
logging.error("Token limit exceeded for hierarchical_execution_v2")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
result = continue_until_done(level, assigned_agent_memory)
|
||||||
|
results.append(result)
|
||||||
|
num_levels += 1
|
||||||
|
|
||||||
|
final_result = combine_results(results)
|
||||||
|
|
||||||
|
logging.info(f"Hierarchical execution completed. Final result: {final_result}")
|
||||||
|
|
||||||
|
# Save the final result to a database (e.g., a graph database like Neo4j for hierarchical relationships)
|
||||||
|
save_result_to_database(task, final_result)
|
||||||
|
|
||||||
|
return final_result
|
||||||
|
|
||||||
|
|
||||||
|
def consensus_based_decision_v2(task_prompt: str, agents_memory: List[Dict[str, Any]], max_tokens: int = 8192) -> str:
|
||||||
|
"""
|
||||||
|
This function takes a task prompt and a list of agent memories, and it returns the consensus-based decision among
|
||||||
|
the agents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_prompt (str): The task prompt to be solved.
|
||||||
|
agents_memory (List[Dict[str, Any]]): A list of agent memory states.
|
||||||
|
max_tokens (int, optional): The maximum number of tokens GPT-4 can handle. Defaults to 8192.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The consensus-based decision among the agents.
|
||||||
|
|
||||||
|
Error handling:
|
||||||
|
If the text exceeds the token limit, an error message is logged, and the function returns an empty string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logging.info(f"Consensus-based decision started for task: {task_prompt}")
|
||||||
|
|
||||||
|
options = collaborative_brainstorm(task_prompt, agents_memory[0], agents_memory[1])
|
||||||
|
votes = []
|
||||||
|
|
||||||
|
for option in options:
|
||||||
|
vote_count = 0
|
||||||
|
|
||||||
|
for agent_memory in agents_memory:
|
||||||
|
chat_input = agent_memory + [{"role": "user", "content": f"Which option do you prefer: {options[0]} or {option}?"}]
|
||||||
|
tokens = count_tokens(json.dumps(chat_input), tokenizer)
|
||||||
|
|
||||||
|
if tokens >= max_tokens:
|
||||||
|
logging.error("Token limit exceeded for consensus_based_decision_v2")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
vote, _ = chat(chat_input)
|
||||||
|
if vote.strip() == option:
|
||||||
|
vote_count += 1
|
||||||
|
|
||||||
|
votes.append(vote_count)
|
||||||
|
|
||||||
|
consensus_option = options[votes.index(max(votes))]
|
||||||
|
|
||||||
|
logging.info(f"Consensus-based decision completed. Final result: {consensus_option}")
|
||||||
|
|
||||||
|
# Save the final result to a database (e.g., a relational database like PostgreSQL for structured data)
|
||||||
|
save_result_to_database(task_prompt, consensus_option)
|
||||||
|
|
||||||
|
return consensus_option
|
||||||
|
|
||||||
|
|
||||||
|
def ask_for_help_v2(chatbot1_memory: List[Dict[str, Any]], chatbot2_memory: List[Dict[str, Any]], max_tokens: int = 8192) -> str:
|
||||||
|
"""
|
||||||
|
This function facilitates the interaction between two chatbots. Chatbot1 asks Chatbot2 for help on a task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chatbot1_memory (List[Dict[str, Any]]): Memory state of Chatbot1.
|
||||||
|
chatbot2_memory (List[Dict[str, Any]]): Memory state of Chatbot2.
|
||||||
|
max_tokens (int, optional): The maximum number of tokens GPT-4 can handle. Defaults to 8192.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The suggestion provided by Chatbot2 to help Chatbot1 with the task.
|
||||||
|
|
||||||
|
Error handling:
|
||||||
|
If the text exceeds the token limit, an error message is logged, and the function returns an empty string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logging.info("Ask for help started")
|
||||||
|
|
||||||
|
chat_input1 = chatbot1_memory + [{"role": "user", "content": "Chatbot1, I need help with this task."}]
|
||||||
|
tokens1 = count_tokens(json.dumps(chat_input1), tokenizer)
|
||||||
|
|
||||||
|
if tokens1 >= max_tokens:
|
||||||
|
logging.error("Token limit exceeded for ask_for_help_v2")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
chatbot1_help_request, chatbot1_tokens = chat(chat_input1)
|
||||||
|
chatbot1_memory.append({"role": "assistant", "content": chatbot1_help_request})
|
||||||
|
|
||||||
|
chat_input2 = chatbot2_memory + [{"role": "user", "content": f"Chatbot2, please help me with this: {chatbot1_help_request}"}]
|
||||||
|
tokens2 = count_tokens(json.dumps(chat_input2), tokenizer)
|
||||||
|
|
||||||
|
if tokens2 >= max_tokens:
|
||||||
|
logging.error("Token limit exceeded for ask_for_help_v2")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
chatbot2_suggestion, chatbot2_tokens = chat(chat_input2)
|
||||||
|
chatbot2_memory.append({"role": "assistant", "content": chatbot2_suggestion})
|
||||||
|
|
||||||
|
logging.info(f"Ask for help completed. Chatbot2's suggestion: {chatbot2_suggestion}")
|
||||||
|
|
||||||
|
# Save the chat history to a database (e.g., a graph database like Neo4j for interconnected data)
|
||||||
|
save_chat_history_to_database(chatbot1_memory, chatbot2_memory)
|
||||||
|
|
||||||
|
return chatbot2_suggestion
|
||||||
|
|
||||||
|
|
||||||
|
def collaborative_brainstorm_v2(topic: str, chatbot1_memory: List[Dict[str, Any]], chatbot2_memory: List[Dict[str, Any]], max_tokens: int = 8192) -> List[str]:
|
||||||
|
"""
|
||||||
|
This function enables two chatbots to collaboratively brainstorm ideas on a given topic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
topic (str): The topic for brainstorming.
|
||||||
|
chatbot1_memory (List[Dict[str, Any]]): Memory state of Chatbot1.
|
||||||
|
chatbot2_memory (List[Dict[str, Any]]): Memory state of Chatbot2.
|
||||||
|
max_tokens (int, optional): The maximum number of tokens GPT-4 can handle. Defaults to 8192.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: A list of brainstormed ideas from both chatbots.
|
||||||
|
|
||||||
|
Error handling:
|
||||||
|
If the text exceeds the token limit, an error message is logged, and the function returns an empty list.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logging.info(f"Collaborative brainstorming started for topic: {topic}")
|
||||||
|
|
||||||
|
ideas = []
|
||||||
|
|
||||||
|
for i in range(3):
|
||||||
|
chat_input1 = chatbot1_memory + [{"role": "user", "content": f"Chatbot1, brainstorm an idea for {topic}"}]
|
||||||
|
tokens1 = count_tokens(json.dumps(chat_input1), tokenizer)
|
||||||
|
|
||||||
|
if tokens1 >= max_tokens:
|
||||||
|
logging.error("Token limit exceeded for collaborative_brainstorm_v2")
|
||||||
|
return []
|
||||||
|
|
||||||
|
chatbot1_idea, chatbot1_tokens = chat(chat_input1)
|
||||||
|
chatbot1_memory.append({"role": "assistant", "content": chatbot1_idea})
|
||||||
|
ideas.append(chatbot1_idea)
|
||||||
|
|
||||||
|
chat_input2 = chatbot2_memory + [{"role": "user", "content": f"Chatbot2, brainstorm an idea for {topic}"}]
|
||||||
|
tokens2 = count_tokens(json.dumps(chat_input2), tokenizer)
|
||||||
|
|
||||||
|
if tokens2 >= max_tokens:
|
||||||
|
logging.error("Token limit exceeded for collaborative_brainstorm_v2")
|
||||||
|
return []
|
||||||
|
|
||||||
|
chatbot2_idea, chatbot2_tokens = chat(chat_input2)
|
||||||
|
chatbot2_memory.append({"role": "assistant", "content": chatbot2_idea})
|
||||||
|
ideas.append(chatbot2_idea)
|
||||||
|
|
||||||
|
logging.info(f"Collaborative brainstorming completed. Ideas: {ideas}")
|
||||||
|
|
||||||
|
# Save the brainstorming session to a database (e.g., a document database like MongoDB for storing complex data structures)
|
||||||
|
save_brainstorming_session_to_database(topic, ideas, chatbot1_memory, chatbot2_memory)
|
||||||
|
|
||||||
|
return ideas
|
||||||
|
|
||||||
|
|
||||||
|
def graph_based_chat_v2(chatbot_memory: List[Dict[str, Any]], user_id: str, user_message: str, graph_database: GraphDatabase, max_tokens: int = 8192) -> str:
|
||||||
|
"""
|
||||||
|
This function allows a chatbot to engage in a conversation with a user, utilizing a graph database to provide insights
|
||||||
|
and connections between users, keywords, and topics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chatbot_memory (List[Dict[str, Any]]): Memory state of the chatbot.
|
||||||
|
user_id (str): The unique identifier for the user.
|
||||||
|
user_message (str): The message from the user.
|
||||||
|
graph_database (GraphDatabase): The graph database containing connections between users, keywords, topics, etc.
|
||||||
|
max_tokens (int, optional): The maximum number of tokens GPT-4 can handle. Defaults to 8192.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The chatbot's response to the user's message.
|
||||||
|
|
||||||
|
Error handling:
|
||||||
|
If the text exceeds the token limit, an error message is logged, and the function returns an empty string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logging.info(f"Received message from user {user_id}: {user_message}")
|
||||||
|
|
||||||
|
# Update the graph database with user's message
|
||||||
|
update_graph_database(user_id, user_message, graph_database)
|
||||||
|
|
||||||
|
# Retrieve insights from the graph database
|
||||||
|
insights = get_insights(graph_database)
|
||||||
|
|
||||||
|
chat_input = chatbot_memory + [{"role": "user", "content": f"{user_message}\nInsights: {insights}"}]
|
||||||
|
tokens = count_tokens(json.dumps(chat_input), tokenizer)
|
||||||
|
|
||||||
|
if tokens >= max_tokens:
|
||||||
|
logging.error("Token limit exceeded for graph_based_chat_v2")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
chatbot_response, chatbot_tokens = chat(chat_input)
|
||||||
|
chatbot_memory.append({"role": "assistant", "content": chatbot_response})
|
||||||
|
|
||||||
|
logging.info(f"Chatbot response to user {user_id}: {chatbot_response}")
|
||||||
|
|
||||||
|
# Save the chat conversation to a database (e.g., a relational database like MySQL for structured data)
|
||||||
|
save_chat_conversation_to_database(user_id, user_message, chatbot_response)
|
||||||
|
|
||||||
|
return chatbot_response
|
||||||
|
|
||||||
|
|
||||||
|
def multi_platform_chat_v2(platform: str, chatbot_memory: List[Dict[str, Any]], user_id: str, user_message: str, max_tokens: int = 8192) -> str:
|
||||||
|
"""
|
||||||
|
This function allows a chatbot to engage in a conversation with a user on various platforms such as
|
||||||
|
WhatsApp, Snapchat, Facebook, Twitter, etc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
platform (str): The platform on which the chat is taking place (e.g., "WhatsApp", "Facebook").
|
||||||
|
chatbot_memory (List[Dict[str, Any]]): Memory state of the chatbot.
|
||||||
|
user_id (str): The unique identifier for the user.
|
||||||
|
user_message (str): The message from the user.
|
||||||
|
max_tokens (int, optional): The maximum number of tokens GPT-4 can handle. Defaults to 8192.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The chatbot's response to the user's message.
|
||||||
|
|
||||||
|
Error handling:
|
||||||
|
If the text exceeds the token limit, an error message is logged, and the function returns an empty string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logging.info(f"Received message from user {user_id} on {platform}: {user_message}")
|
||||||
|
|
||||||
|
chat_input = chatbot_memory + [{"role": "user", "content": f"Platform: {platform}\nUser message: {user_message}"}]
|
||||||
|
tokens = count_tokens(json.dumps(chat_input), tokenizer)
|
||||||
|
|
||||||
|
if tokens >= max_tokens:
|
||||||
|
logging.error("Token limit exceeded for multi_platform_chat_v2")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
chatbot_response, chatbot_tokens = chat(chat_input)
|
||||||
|
chatbot_memory.append({"role": "assistant", "content": chatbot_response})
|
||||||
|
|
||||||
|
logging.info(f"Chatbot response to user {user_id} on {platform}: {chatbot_response}")
|
||||||
|
|
||||||
|
# Save the chat conversation to a database (e.g., a document-based database like MongoDB for unstructured data)
|
||||||
|
save_chat_conversation_to_database(user_id, platform, user_message, chatbot_response)
|
||||||
|
|
||||||
|
return chatbot_response
|
||||||
|
|
||||||
|
|
||||||
|
def agent_swapping_v2(task_prompt: str, agents_memory: List[Dict[str, Any]], max_tokens: int = 8192) -> str:
|
||||||
|
"""
|
||||||
|
This function allows multiple agents to collaboratively solve a task by swapping in and out when
|
||||||
|
their individual knowledge is insufficient.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_prompt (str): The task to be solved.
|
||||||
|
agents_memory (List[Dict[str, Any]]): List of memory states for each agent.
|
||||||
|
max_tokens (int, optional): The maximum number of tokens GPT-4 can handle. Defaults to 8192.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The final solution to the task.
|
||||||
|
|
||||||
|
Error handling:
|
||||||
|
If the text exceeds the token limit, an error message is logged, and the function returns an empty string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
current_agent_index = 0
|
||||||
|
current_agent_memory = agents_memory[current_agent_index]
|
||||||
|
input_messages = current_agent_memory + [{"role": "user", "content": f"Task: {task_prompt}"}]
|
||||||
|
tokens = count_tokens(json.dumps(input_messages), tokenizer)
|
||||||
|
|
||||||
|
if tokens >= max_tokens:
|
||||||
|
logging.error("Token limit exceeded for agent_swapping_v2")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
partial_solution, remaining_task = chat(input_messages)
|
||||||
|
|
||||||
|
while remaining_task:
|
||||||
|
current_agent_index = (current_agent_index + 1) % len(agents_memory)
|
||||||
|
current_agent_memory = agents_memory[current_agent_index]
|
||||||
|
input_messages = current_agent_memory + [{"role": "user", "content": f"Remaining task: {remaining_task}"}]
|
||||||
|
tokens = count_tokens(json.dumps(input_messages), tokenizer)
|
||||||
|
|
||||||
|
if tokens >= max_tokens:
|
||||||
|
logging.error("Token limit exceeded for agent_swapping_v2")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
next_partial_solution, remaining_task = chat(input_messages)
|
||||||
|
partial_solution += next_partial_solution
|
||||||
|
|
||||||
|
return partial_solution
|
||||||
|
|
||||||
|
|
||||||
|
def multi_agent_voting_v2(task_prompt: str, agents_memory: List[Dict[str, Any]], max_tokens: int = 8192) -> str:
|
||||||
|
"""
|
||||||
|
This function allows multiple agents to collaboratively solve a task by proposing solutions and
|
||||||
|
voting on the best one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_prompt (str): The task to be solved.
|
||||||
|
agents_memory (List[Dict[str, Any]]): List of memory states for each agent.
|
||||||
|
max_tokens (int, optional): The maximum number of tokens GPT-4 can handle. Defaults to 8192.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The final solution to the task.
|
||||||
|
|
||||||
|
Error handling:
|
||||||
|
If the text exceeds the token limit, an error message is logged, and the function returns an empty string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
proposed_solutions = []
|
||||||
|
for agent_memory in agents_memory:
|
||||||
|
input_messages = agent_memory + [{"role": "user", "content": f"Propose a solution for: {task_prompt}"}]
|
||||||
|
tokens = count_tokens(json.dumps(input_messages), tokenizer)
|
||||||
|
|
||||||
|
if tokens >= max_tokens:
|
||||||
|
logging.error("Token limit exceeded for multi_agent_voting_v2")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
proposed_solution, _ = chat(input_messages)
|
||||||
|
proposed_solutions.append(proposed_solution.strip())
|
||||||
|
|
||||||
|
input_messages = [{"role": "system", "content": "You are an AI agent. Vote on the best solution from the following options:"}] + [{"role": "assistant", "content": option} for option in proposed_solutions]
|
||||||
|
tokens = count_tokens(json.dumps(input_messages), tokenizer)
|
||||||
|
|
||||||
|
if tokens >= max_tokens:
|
||||||
|
logging.error("Token limit exceeded for multi_agent_voting_v2")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
winning_solution, _ = chat(input_messages + [{"role": "user", "content": "Which solution is the best?"}])
|
||||||
|
return winning_solution.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def multi_agent_brainstorming_v2(topic: str, agents_memory: List[Dict[str, Any]], max_tokens: int = 8192) -> List[str]:
|
||||||
|
"""
|
||||||
|
This function allows multiple agents to collaboratively brainstorm ideas on a given topic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
topic (str): The topic for brainstorming.
|
||||||
|
agents_memory (List[Dict[str, Any]]): List of memory states for each agent.
|
||||||
|
max_tokens (int, optional): The maximum number of tokens GPT-4 can handle. Defaults to 8192.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: List of brainstormed ideas.
|
||||||
|
|
||||||
|
Error handling:
|
||||||
|
If the text exceeds the token limit, an error message is logged, and the function returns an empty list.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ideas = []
|
||||||
|
|
||||||
|
for agent_memory in agents_memory:
|
||||||
|
input_messages = agent_memory + [{"role": "user", "content": f"Brainstorm an idea for: {topic}"}]
|
||||||
|
tokens = count_tokens(json.dumps(input_messages), tokenizer)
|
||||||
|
|
||||||
|
if tokens >= max_tokens:
|
||||||
|
logging.error("Token limit exceeded for multi_agent_brainstorming_v2")
|
||||||
|
return []
|
||||||
|
|
||||||
|
idea, _ = chat(input_messages)
|
||||||
|
ideas.append(idea.strip())
|
||||||
|
|
||||||
|
return ideas
|
||||||
|
|
||||||
|
|
||||||
|
def multi_agent_emotion_analysis_v2(text: str, agents_memory: List[Dict[str, Any]], max_tokens: int = 8192) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
This function allows multiple agents to perform emotion analysis on a given text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): The text to perform emotion analysis on.
|
||||||
|
agents_memory (List[Dict[str, Any]]): List of memory states for each agent.
|
||||||
|
max_tokens (int, optional): The maximum number of tokens GPT-4 can handle. Defaults to 8192.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, float]: A dictionary containing emotion scores for the text.
|
||||||
|
|
||||||
|
Error handling:
|
||||||
|
If the text exceeds the token limit, an error message is logged, and the function returns an empty dictionary.
|
||||||
|
"""
|
||||||
|
|
||||||
|
emotion_scores = defaultdict(float)
|
||||||
|
|
||||||
|
for agent_memory in agents_memory:
|
||||||
|
input_messages = agent_memory + [{"role": "user", "content": f"Analyze the emotions in this text: {text}"}]
|
||||||
|
tokens = count_tokens(json.dumps(input_messages), tokenizer)
|
||||||
|
|
||||||
|
if tokens >= max_tokens:
|
||||||
|
logging.error("Token limit exceeded for multi_agent_emotion_analysis_v2")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
emotion_analysis, _ = chat(input_messages)
|
||||||
|
parsed_scores = json.loads(emotion_analysis.strip())
|
||||||
|
|
||||||
|
for emotion, score in parsed_scores.items():
|
||||||
|
emotion_scores[emotion] += score
|
||||||
|
|
||||||
|
for emotion in emotion_scores:
|
||||||
|
emotion_scores[emotion] /= len(agents_memory)
|
||||||
|
|
||||||
|
return emotion_scores
|
||||||
|
|
||||||
|
def swarm_intelligence(task_prompt, agents_memory):
|
||||||
|
subtasks = generate_tasks(task_prompt)
|
||||||
|
results = []
|
||||||
|
for subtask in subtasks:
|
||||||
|
agent_votes = []
|
||||||
|
for agent_memory in agents_memory:
|
||||||
|
agent_vote, _ = chat(agent_memory + [{"role": "user", "content": f"Propose a solution for: {subtask}"}])
|
||||||
|
agent_votes.append(agent_vote.strip())
|
||||||
|
most_common_solution = max(set(agent_votes), key=agent_votes.count)
|
||||||
|
results.append(most_common_solution)
|
||||||
|
return results
|
@ -0,0 +1,95 @@
|
|||||||
|
# General
|
||||||
|
import os
|
||||||
|
import pandas as pd
|
||||||
|
from langchain.experimental.autonomous_agents.autogpt.agent import AutoGPT
|
||||||
|
|
||||||
|
from langchain.chat_models import ChatOpenAI
|
||||||
|
from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import nest_asyncio
|
||||||
|
|
||||||
|
# Tools
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Optional
|
||||||
|
from langchain.agents import tool
|
||||||
|
|
||||||
|
from langchain.tools.file_management.read import ReadFileTool
|
||||||
|
from langchain.tools.file_management.write import WriteFileTool
|
||||||
|
from langchain.tools import BaseTool, DuckDuckGoSearchRun
|
||||||
|
|
||||||
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||||
|
from pydantic import Field
|
||||||
|
from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain, BaseCombineDocumentsChain
|
||||||
|
|
||||||
|
# Memory
|
||||||
|
import faiss
|
||||||
|
from langchain.vectorstores import FAISS
|
||||||
|
from langchain.docstore import InMemoryDocstore
|
||||||
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
|
|
||||||
|
from langchain.tools.human.tool import HumanInputRun
|
||||||
|
# from swarms.agents.workers.auto_agent import
|
||||||
|
from swarms.agents.workers.visual_agent import multimodal_agent_tool
|
||||||
|
from swarms.tools.main import Terminal, CodeWriter, CodeEditor, process_csv, WebpageQATool
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerAgent:
|
||||||
|
def __init__(self, objective: str, api_key: str):
|
||||||
|
self.objective = objective
|
||||||
|
self.api_key = api_key
|
||||||
|
self.worker = self.create_agent_worker()
|
||||||
|
|
||||||
|
def create_agent_worker(self):
|
||||||
|
os.environ['OPENAI_API_KEY'] = self.api_key
|
||||||
|
|
||||||
|
llm = ChatOpenAI(model_name="gpt-4", temperature=1.0)
|
||||||
|
embeddings_model = OpenAIEmbeddings()
|
||||||
|
embedding_size = 1536
|
||||||
|
index = faiss.IndexFlatL2(embedding_size)
|
||||||
|
vectorstore = FAISS(embeddings_model.embed_query, index, InMemoryDocstore({}), {})
|
||||||
|
|
||||||
|
query_website_tool = WebpageQATool(qa_chain=load_qa_with_sources_chain(llm))
|
||||||
|
web_search = DuckDuckGoSearchRun()
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
web_search,
|
||||||
|
WriteFileTool(root_dir="./data"),
|
||||||
|
ReadFileTool(root_dir="./data"),
|
||||||
|
|
||||||
|
multimodal_agent_tool,
|
||||||
|
process_csv,
|
||||||
|
query_website_tool,
|
||||||
|
Terminal,
|
||||||
|
|
||||||
|
|
||||||
|
CodeWriter,
|
||||||
|
CodeEditor
|
||||||
|
]
|
||||||
|
|
||||||
|
agent_worker = AutoGPT.from_llm_and_tools(
|
||||||
|
ai_name="WorkerX",
|
||||||
|
ai_role="Assistant",
|
||||||
|
tools=tools,
|
||||||
|
llm=llm,
|
||||||
|
memory=vectorstore.as_retriever(search_kwargs={"k": 8}),
|
||||||
|
human_in_the_loop=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
agent_worker.chain.verbose = True
|
||||||
|
|
||||||
|
return agent_worker
|
||||||
|
|
||||||
|
# objective = "Your objective here"
|
||||||
|
# api_key = "Your OpenAI API key here"
|
||||||
|
|
||||||
|
# worker_agent = WorkerAgent(objective, api_key)
|
||||||
|
|
||||||
|
|
||||||
|
# objective = "Your objective here"
|
||||||
|
|
||||||
|
|
||||||
|
# worker_agent = WorkerAgent(objective)
|
@ -0,0 +1,99 @@
|
|||||||
|
from langchain import OpenAI, LLMChain, PromptTemplate
|
||||||
|
from langchain.memory import ConversationBufferWindowMemory
|
||||||
|
|
||||||
|
def initialize_chain(instructions, memory=None):
|
||||||
|
if memory is None:
|
||||||
|
memory = ConversationBufferWindowMemory()
|
||||||
|
memory.ai_prefix = "Assistant"
|
||||||
|
|
||||||
|
template = f"""
|
||||||
|
Instructions: {instructions}
|
||||||
|
{{{memory.memory_key}}}
|
||||||
|
Human: {{human_input}}
|
||||||
|
Assistant:"""
|
||||||
|
|
||||||
|
prompt = PromptTemplate(
|
||||||
|
input_variables=["history", "human_input"], template=template
|
||||||
|
)
|
||||||
|
|
||||||
|
chain = LLMChain(
|
||||||
|
llm=OpenAI(temperature=0),
|
||||||
|
prompt=prompt,
|
||||||
|
verbose=True,
|
||||||
|
memory=ConversationBufferWindowMemory(),
|
||||||
|
)
|
||||||
|
return chain
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_meta_chain():
|
||||||
|
meta_template = """
|
||||||
|
Assistant has just had the below interactions with a User. Assistant followed their "Instructions" closely. Your job is to critique the Assistant's performance and then revise the Instructions so that Assistant would quickly and correctly respond in the future.
|
||||||
|
|
||||||
|
####
|
||||||
|
|
||||||
|
{chat_history}
|
||||||
|
|
||||||
|
####
|
||||||
|
|
||||||
|
Please reflect on these interactions.
|
||||||
|
|
||||||
|
You should first critique Assistant's performance. What could Assistant have done better? What should the Assistant remember about this user? Are there things this user always wants? Indicate this with "Critique: ...".
|
||||||
|
|
||||||
|
You should next revise the Instructions so that Assistant would quickly and correctly respond in the future. Assistant's goal is to satisfy the user in as few interactions as possible. Assistant will only see the new Instructions, not the interaction history, so anything important must be summarized in the Instructions. Don't forget any important details in the current Instructions! Indicate the new Instructions by "Instructions: ...".
|
||||||
|
"""
|
||||||
|
|
||||||
|
meta_prompt = PromptTemplate(
|
||||||
|
input_variables=["chat_history"], template=meta_template
|
||||||
|
)
|
||||||
|
|
||||||
|
meta_chain = LLMChain(
|
||||||
|
llm=OpenAI(temperature=0),
|
||||||
|
prompt=meta_prompt,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
return meta_chain
|
||||||
|
|
||||||
|
|
||||||
|
def get_chat_history(chain_memory):
|
||||||
|
memory_key = chain_memory.memory_key
|
||||||
|
chat_history = chain_memory.load_memory_variables(memory_key)[memory_key]
|
||||||
|
return chat_history
|
||||||
|
|
||||||
|
|
||||||
|
def get_new_instructions(meta_output):
|
||||||
|
delimiter = "Instructions: "
|
||||||
|
new_instructions = meta_output[meta_output.find(delimiter) + len(delimiter) :]
|
||||||
|
return new_instructions
|
||||||
|
|
||||||
|
def meta_agent(task, max_iters=3, max_meta_iters=5):
|
||||||
|
failed_phrase = "task failed"
|
||||||
|
success_phrase = "task succeeded"
|
||||||
|
key_phrases = [success_phrase, failed_phrase]
|
||||||
|
|
||||||
|
instructions = "None"
|
||||||
|
for i in range(max_meta_iters):
|
||||||
|
print(f"[Episode {i+1}/{max_meta_iters}]")
|
||||||
|
chain = initialize_chain(instructions, memory=None)
|
||||||
|
output = chain.predict(human_input=task)
|
||||||
|
for j in range(max_iters):
|
||||||
|
print(f"(Step {j+1}/{max_iters})")
|
||||||
|
print(f"Assistant: {output}")
|
||||||
|
print(f"Human: ")
|
||||||
|
human_input = input()
|
||||||
|
if any(phrase in human_input.lower() for phrase in key_phrases):
|
||||||
|
break
|
||||||
|
output = chain.predict(human_input=human_input)
|
||||||
|
if success_phrase in human_input.lower():
|
||||||
|
print(f"You succeeded! Thanks for playing!")
|
||||||
|
return
|
||||||
|
meta_chain = initialize_meta_chain()
|
||||||
|
meta_output = meta_chain.predict(chat_history=get_chat_history(chain.memory))
|
||||||
|
print(f"Feedback: {meta_output}")
|
||||||
|
instructions = get_new_instructions(meta_output)
|
||||||
|
print(f"New Instructions: {instructions}")
|
||||||
|
print("\n" + "#" * 80 + "\n")
|
||||||
|
print(f"You failed! Thanks for playing!")
|
||||||
|
|
||||||
|
|
||||||
|
task = "Provide a systematic argument for why we should always eat pasta with olives."
|
||||||
|
meta_agent(task)
|
After Width: | Height: | Size: 256 KiB |
After Width: | Height: | Size: 286 KiB |
After Width: | Height: | Size: 555 KiB |
After Width: | Height: | Size: 120 KiB |
After Width: | Height: | Size: 373 KiB |
After Width: | Height: | Size: 354 KiB |
After Width: | Height: | Size: 472 KiB |
After Width: | Height: | Size: 456 KiB |
@ -0,0 +1,146 @@
|
|||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# vscode
|
||||||
|
.vscode/
|
||||||
|
output/
|
||||||
|
outputs/
|
||||||
|
subs/
|
||||||
|
logs/
|
||||||
|
|
||||||
|
grounding/config/configs
|
||||||
|
grounding/version.py
|
||||||
|
|
||||||
|
vis/
|
||||||
|
tmp/
|
@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright 2023 - present, IDEA Research.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
@ -0,0 +1,327 @@
|
|||||||
|
<div align="center">
|
||||||
|
<img src="./.asset/grounding_dino_logo.png" width="30%">
|
||||||
|
</div>
|
||||||
|
|
||||||
|
# :sauropod: Grounding DINO
|
||||||
|
|
||||||
|
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/zero-shot-object-detection-on-mscoco)](https://paperswithcode.com/sota/zero-shot-object-detection-on-mscoco?p=grounding-dino-marrying-dino-with-grounded) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/zero-shot-object-detection-on-odinw)](https://paperswithcode.com/sota/zero-shot-object-detection-on-odinw?p=grounding-dino-marrying-dino-with-grounded) \
|
||||||
|
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/object-detection-on-coco-minival)](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=grounding-dino-marrying-dino-with-grounded) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/object-detection-on-coco)](https://paperswithcode.com/sota/object-detection-on-coco?p=grounding-dino-marrying-dino-with-grounded)
|
||||||
|
|
||||||
|
|
||||||
|
**[IDEA-CVR, IDEA-Research](https://github.com/IDEA-Research)**
|
||||||
|
|
||||||
|
[Shilong Liu](http://www.lsl.zone/), [Zhaoyang Zeng](https://scholar.google.com/citations?user=U_cvvUwAAAAJ&hl=zh-CN&oi=ao), [Tianhe Ren](https://rentainhe.github.io/), [Feng Li](https://scholar.google.com/citations?user=ybRe9GcAAAAJ&hl=zh-CN), [Hao Zhang](https://scholar.google.com/citations?user=B8hPxMQAAAAJ&hl=zh-CN), [Jie Yang](https://github.com/yangjie-cv), [Chunyuan Li](https://scholar.google.com/citations?user=Zd7WmXUAAAAJ&hl=zh-CN&oi=ao), [Jianwei Yang](https://jwyang.github.io/), [Hang Su](https://scholar.google.com/citations?hl=en&user=dxN1_X0AAAAJ&view_op=list_works&sortby=pubdate), [Jun Zhu](https://scholar.google.com/citations?hl=en&user=axsP38wAAAAJ), [Lei Zhang](https://www.leizhang.org/)<sup>:email:</sup>.
|
||||||
|
|
||||||
|
|
||||||
|
[[`Paper`](https://arxiv.org/abs/2303.05499)] [[`Demo`](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo)] [[`BibTex`](#black_nib-citation)]
|
||||||
|
|
||||||
|
|
||||||
|
PyTorch implementation and pretrained models for Grounding DINO. For details, see the paper **[Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection](https://arxiv.org/abs/2303.05499)**.
|
||||||
|
|
||||||
|
## :sun_with_face: Helpful Tutorial
|
||||||
|
|
||||||
|
- :grapes: [[Read our arXiv Paper](https://arxiv.org/abs/2303.05499)]
|
||||||
|
- :apple: [[Watch our simple introduction video on YouTube](https://youtu.be/wxWDt5UiwY8)]
|
||||||
|
- :blossom: [[Try the Colab Demo](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb)]
|
||||||
|
- :sunflower: [[Try our Official Huggingface Demo](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo)]
|
||||||
|
- :maple_leaf: [[Watch the Step by Step Tutorial about GroundingDINO by Roboflow AI](https://youtu.be/cMa77r3YrDk)]
|
||||||
|
- :mushroom: [[GroundingDINO: Automated Dataset Annotation and Evaluation by Roboflow AI](https://youtu.be/C4NqaRBz_Kw)]
|
||||||
|
- :hibiscus: [[Accelerate Image Annotation with SAM and GroundingDINO by Roboflow AI](https://youtu.be/oEQYStnF2l8)]
|
||||||
|
- :white_flower: [[Autodistill: Train YOLOv8 with ZERO Annotations based on Grounding-DINO and Grounded-SAM by Roboflow AI](https://github.com/autodistill/autodistill)]
|
||||||
|
|
||||||
|
<!-- Grounding DINO Methods |
|
||||||
|
[![arXiv](https://img.shields.io/badge/arXiv-2303.05499-b31b1b.svg)](https://arxiv.org/abs/2303.05499)
|
||||||
|
[![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/wxWDt5UiwY8) -->
|
||||||
|
|
||||||
|
<!-- Grounding DINO Demos |
|
||||||
|
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) -->
|
||||||
|
<!-- [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/cMa77r3YrDk)
|
||||||
|
[![HuggingFace space](https://img.shields.io/badge/🤗-HuggingFace%20Space-cyan.svg)](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo)
|
||||||
|
[![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/oEQYStnF2l8)
|
||||||
|
[![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/C4NqaRBz_Kw) -->
|
||||||
|
|
||||||
|
## :sparkles: Highlight Projects
|
||||||
|
|
||||||
|
- [DetGPT: Detect What You Need via Reasoning](https://github.com/OptimalScale/DetGPT)
|
||||||
|
- [Grounded-SAM: Marrying Grounding DINO with Segment Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything)
|
||||||
|
- [Grounding DINO with Stable Diffusion](demo/image_editing_with_groundingdino_stablediffusion.ipynb)
|
||||||
|
- [Grounding DINO with GLIGEN for Controllable Image Editing](demo/image_editing_with_groundingdino_gligen.ipynb)
|
||||||
|
- [OpenSeeD: A Simple and Strong Openset Segmentation Model](https://github.com/IDEA-Research/OpenSeeD)
|
||||||
|
- [SEEM: Segment Everything Everywhere All at Once](https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once)
|
||||||
|
- [X-GPT: Conversational Visual Agent supported by X-Decoder](https://github.com/microsoft/X-Decoder/tree/xgpt)
|
||||||
|
- [GLIGEN: Open-Set Grounded Text-to-Image Generation](https://github.com/gligen/GLIGEN)
|
||||||
|
- [LLaVA: Large Language and Vision Assistant](https://github.com/haotian-liu/LLaVA)
|
||||||
|
|
||||||
|
<!-- Extensions | [Grounding DINO with Segment Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything); [Grounding DINO with Stable Diffusion](demo/image_editing_with_groundingdino_stablediffusion.ipynb); [Grounding DINO with GLIGEN](demo/image_editing_with_groundingdino_gligen.ipynb) -->
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
<!-- Official PyTorch implementation of [Grounding DINO](https://arxiv.org/abs/2303.05499), a stronger open-set object detector. Code is available now! -->
|
||||||
|
|
||||||
|
|
||||||
|
## :bulb: Highlight
|
||||||
|
|
||||||
|
- **Open-Set Detection.** Detect **everything** with language!
|
||||||
|
- **High Performancce.** COCO zero-shot **52.5 AP** (training without COCO data!). COCO fine-tune **63.0 AP**.
|
||||||
|
- **Flexible.** Collaboration with Stable Diffusion for Image Editting.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## :fire: News
|
||||||
|
- **`2023/06/17`**: We provide an example to evaluate Grounding DINO on COCO zero-shot performance.
|
||||||
|
- **`2023/04/15`**: Refer to [CV in the Wild Readings](https://github.com/Computer-Vision-in-the-Wild/CVinW_Readings) for those who are interested in open-set recognition!
|
||||||
|
- **`2023/04/08`**: We release [demos](demo/image_editing_with_groundingdino_gligen.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [GLIGEN](https://github.com/gligen/GLIGEN) for more controllable image editings.
|
||||||
|
- **`2023/04/08`**: We release [demos](demo/image_editing_with_groundingdino_stablediffusion.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) for image editings.
|
||||||
|
- **`2023/04/06`**: We build a new demo by marrying GroundingDINO with [Segment-Anything](https://github.com/facebookresearch/segment-anything) named **[Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything)** aims to support segmentation in GroundingDINO.
|
||||||
|
- **`2023/03/28`**: A YouTube [video](https://youtu.be/cMa77r3YrDk) about Grounding DINO and basic object detection prompt engineering. [[SkalskiP](https://github.com/SkalskiP)]
|
||||||
|
- **`2023/03/28`**: Add a [demo](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo) on Hugging Face Space!
|
||||||
|
- **`2023/03/27`**: Support CPU-only mode. Now the model can run on machines without GPUs.
|
||||||
|
- **`2023/03/25`**: A [demo](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) for Grounding DINO is available at Colab. [[SkalskiP](https://github.com/SkalskiP)]
|
||||||
|
- **`2023/03/22`**: Code is available Now!
|
||||||
|
|
||||||
|
<details open>
|
||||||
|
<summary><font size="4">
|
||||||
|
Description
|
||||||
|
</font></summary>
|
||||||
|
<a href="https://arxiv.org/abs/2303.05499">Paper</a> introduction.
|
||||||
|
<img src=".asset/hero_figure.png" alt="ODinW" width="100%">
|
||||||
|
Marrying <a href="https://github.com/IDEA-Research/GroundingDINO">Grounding DINO</a> and <a href="https://github.com/gligen/GLIGEN">GLIGEN</a>
|
||||||
|
<img src="https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GD_GLIGEN.png" alt="gd_gligen" width="100%">
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## :star: Explanations/Tips for Grounding DINO Inputs and Outputs
|
||||||
|
- Grounding DINO accepts an `(image, text)` pair as inputs.
|
||||||
|
- It outputs `900` (by default) object boxes. Each box has similarity scores across all input words. (as shown in Figures below.)
|
||||||
|
- We defaultly choose the boxes whose highest similarities are higher than a `box_threshold`.
|
||||||
|
- We extract the words whose similarities are higher than the `text_threshold` as predicted labels.
|
||||||
|
- If you want to obtain objects of specific phrases, like the `dogs` in the sentence `two dogs with a stick.`, you can select the boxes with highest text similarities with `dogs` as final outputs.
|
||||||
|
- Note that each word can be split to **more than one** tokens with different tokenlizers. The number of words in a sentence may not equal to the number of text tokens.
|
||||||
|
- We suggest separating different category names with `.` for Grounding DINO.
|
||||||
|
![model_explain1](.asset/model_explan1.PNG)
|
||||||
|
![model_explain2](.asset/model_explan2.PNG)
|
||||||
|
|
||||||
|
## :label: TODO
|
||||||
|
|
||||||
|
- [x] Release inference code and demo.
|
||||||
|
- [x] Release checkpoints.
|
||||||
|
- [x] Grounding DINO with Stable Diffusion and GLIGEN demos.
|
||||||
|
- [ ] Release training codes.
|
||||||
|
|
||||||
|
## :hammer_and_wrench: Install
|
||||||
|
|
||||||
|
**Note:**
|
||||||
|
|
||||||
|
If you have a CUDA environment, please make sure the environment variable `CUDA_HOME` is set. It will be compiled under CPU-only mode if no CUDA available.
|
||||||
|
|
||||||
|
**Installation:**
|
||||||
|
|
||||||
|
Clone the GroundingDINO repository from GitHub.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/IDEA-Research/GroundingDINO.git
|
||||||
|
```
|
||||||
|
|
||||||
|
Change the current directory to the GroundingDINO folder.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd GroundingDINO/
|
||||||
|
```
|
||||||
|
|
||||||
|
Install the required dependencies in the current directory.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
Download pre-trained model weights.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
mkdir weights
|
||||||
|
cd weights
|
||||||
|
wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
|
||||||
|
cd ..
|
||||||
|
```
|
||||||
|
|
||||||
|
## :arrow_forward: Demo
|
||||||
|
Check your GPU ID (only if you're using a GPU)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
nvidia-smi
|
||||||
|
```
|
||||||
|
Replace `{GPU ID}`, `image_you_want_to_detect.jpg`, and `"dir you want to save the output"` with appropriate values in the following command
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES={GPU ID} python demo/inference_on_a_image.py \
|
||||||
|
-c groundingdino/config/GroundingDINO_SwinT_OGC.py \
|
||||||
|
-p weights/groundingdino_swint_ogc.pth \
|
||||||
|
-i image_you_want_to_detect.jpg \
|
||||||
|
-o "dir you want to save the output" \
|
||||||
|
-t "chair"
|
||||||
|
[--cpu-only] # open it for cpu mode
|
||||||
|
```
|
||||||
|
|
||||||
|
If you would like to specify the phrases to detect, here is a demo:
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES={GPU ID} python demo/inference_on_a_image.py \
|
||||||
|
-c groundingdino/config/GroundingDINO_SwinT_OGC.py \
|
||||||
|
-p ./groundingdino_swint_ogc.pth \
|
||||||
|
-i .asset/cat_dog.jpeg \
|
||||||
|
-o logs/1111 \
|
||||||
|
-t "There is a cat and a dog in the image ." \
|
||||||
|
--token_spans "[[[9, 10], [11, 14]], [[19, 20], [21, 24]]]"
|
||||||
|
[--cpu-only] # open it for cpu mode
|
||||||
|
```
|
||||||
|
The token_spans specify the start and end positions of a phrases. For example, the first phrase is `[[9, 10], [11, 14]]`. `"There is a cat and a dog in the image ."[9:10] = 'a'`, `"There is a cat and a dog in the image ."[11:14] = 'cat'`. Hence it refers to the phrase `a cat` . Similarly, the `[[19, 20], [21, 24]]` refers to the phrase `a dog`.
|
||||||
|
|
||||||
|
See the `demo/inference_on_a_image.py` for more details.
|
||||||
|
|
||||||
|
**Running with Python:**
|
||||||
|
|
||||||
|
```python
|
||||||
|
from groundingdino.util.inference import load_model, load_image, predict, annotate
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
model = load_model("groundingdino/config/GroundingDINO_SwinT_OGC.py", "weights/groundingdino_swint_ogc.pth")
|
||||||
|
IMAGE_PATH = "weights/dog-3.jpeg"
|
||||||
|
TEXT_PROMPT = "chair . person . dog ."
|
||||||
|
BOX_TRESHOLD = 0.35
|
||||||
|
TEXT_TRESHOLD = 0.25
|
||||||
|
|
||||||
|
image_source, image = load_image(IMAGE_PATH)
|
||||||
|
|
||||||
|
boxes, logits, phrases = predict(
|
||||||
|
model=model,
|
||||||
|
image=image,
|
||||||
|
caption=TEXT_PROMPT,
|
||||||
|
box_threshold=BOX_TRESHOLD,
|
||||||
|
text_threshold=TEXT_TRESHOLD
|
||||||
|
)
|
||||||
|
|
||||||
|
annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
|
||||||
|
cv2.imwrite("annotated_image.jpg", annotated_frame)
|
||||||
|
```
|
||||||
|
**Web UI**
|
||||||
|
|
||||||
|
We also provide a demo code to integrate Grounding DINO with Gradio Web UI. See the file `demo/gradio_app.py` for more details.
|
||||||
|
|
||||||
|
**Notebooks**
|
||||||
|
|
||||||
|
- We release [demos](demo/image_editing_with_groundingdino_gligen.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [GLIGEN](https://github.com/gligen/GLIGEN) for more controllable image editings.
|
||||||
|
- We release [demos](demo/image_editing_with_groundingdino_stablediffusion.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) for image editings.
|
||||||
|
|
||||||
|
## COCO Zero-shot Evaluations
|
||||||
|
|
||||||
|
We provide an example to evaluate Grounding DINO zero-shot performance on COCO. The results should be **48.5**.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=0 \
|
||||||
|
python demo/test_ap_on_coco.py \
|
||||||
|
-c groundingdino/config/GroundingDINO_SwinT_OGC.py \
|
||||||
|
-p weights/groundingdino_swint_ogc.pth \
|
||||||
|
--anno_path /path/to/annoataions/ie/instances_val2017.json \
|
||||||
|
--image_dir /path/to/imagedir/ie/val2017
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## :luggage: Checkpoints
|
||||||
|
|
||||||
|
<!-- insert a table -->
|
||||||
|
<table>
|
||||||
|
<thead>
|
||||||
|
<tr style="text-align: right;">
|
||||||
|
<th></th>
|
||||||
|
<th>name</th>
|
||||||
|
<th>backbone</th>
|
||||||
|
<th>Data</th>
|
||||||
|
<th>box AP on COCO</th>
|
||||||
|
<th>Checkpoint</th>
|
||||||
|
<th>Config</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
<tr>
|
||||||
|
<th>1</th>
|
||||||
|
<td>GroundingDINO-T</td>
|
||||||
|
<td>Swin-T</td>
|
||||||
|
<td>O365,GoldG,Cap4M</td>
|
||||||
|
<td>48.4 (zero-shot) / 57.2 (fine-tune)</td>
|
||||||
|
<td><a href="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth">GitHub link</a> | <a href="https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth">HF link</a></td>
|
||||||
|
<td><a href="https://github.com/IDEA-Research/GroundingDINO/blob/main/groundingdino/config/GroundingDINO_SwinT_OGC.py">link</a></td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<th>2</th>
|
||||||
|
<td>GroundingDINO-B</td>
|
||||||
|
<td>Swin-B</td>
|
||||||
|
<td>COCO,O365,GoldG,Cap4M,OpenImage,ODinW-35,RefCOCO</td>
|
||||||
|
<td>56.7 </td>
|
||||||
|
<td><a href="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha2/groundingdino_swinb_cogcoor.pth">GitHub link</a> | <a href="https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth">HF link</a>
|
||||||
|
<td><a href="https://github.com/IDEA-Research/GroundingDINO/blob/main/groundingdino/config/GroundingDINO_SwinB.cfg.py">link</a></td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
|
||||||
|
## :medal_military: Results
|
||||||
|
|
||||||
|
<details open>
|
||||||
|
<summary><font size="4">
|
||||||
|
COCO Object Detection Results
|
||||||
|
</font></summary>
|
||||||
|
<img src=".asset/COCO.png" alt="COCO" width="100%">
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details open>
|
||||||
|
<summary><font size="4">
|
||||||
|
ODinW Object Detection Results
|
||||||
|
</font></summary>
|
||||||
|
<img src=".asset/ODinW.png" alt="ODinW" width="100%">
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details open>
|
||||||
|
<summary><font size="4">
|
||||||
|
Marrying Grounding DINO with <a href="https://github.com/Stability-AI/StableDiffusion">Stable Diffusion</a> for Image Editing
|
||||||
|
</font></summary>
|
||||||
|
See our example <a href="https://github.com/IDEA-Research/GroundingDINO/blob/main/demo/image_editing_with_groundingdino_stablediffusion.ipynb">notebook</a> for more details.
|
||||||
|
<img src=".asset/GD_SD.png" alt="GD_SD" width="100%">
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details open>
|
||||||
|
<summary><font size="4">
|
||||||
|
Marrying Grounding DINO with <a href="https://github.com/gligen/GLIGEN">GLIGEN</a> for more Detailed Image Editing.
|
||||||
|
</font></summary>
|
||||||
|
See our example <a href="https://github.com/IDEA-Research/GroundingDINO/blob/main/demo/image_editing_with_groundingdino_gligen.ipynb">notebook</a> for more details.
|
||||||
|
<img src=".asset/GD_GLIGEN.png" alt="GD_GLIGEN" width="100%">
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## :sauropod: Model: Grounding DINO
|
||||||
|
|
||||||
|
Includes: a text backbone, an image backbone, a feature enhancer, a language-guided query selection, and a cross-modality decoder.
|
||||||
|
|
||||||
|
![arch](.asset/arch.png)
|
||||||
|
|
||||||
|
|
||||||
|
## :hearts: Acknowledgement
|
||||||
|
|
||||||
|
Our model is related to [DINO](https://github.com/IDEA-Research/DINO) and [GLIP](https://github.com/microsoft/GLIP). Thanks for their great work!
|
||||||
|
|
||||||
|
We also thank great previous work including DETR, Deformable DETR, SMCA, Conditional DETR, Anchor DETR, Dynamic DETR, DAB-DETR, DN-DETR, etc. More related work are available at [Awesome Detection Transformer](https://github.com/IDEACVR/awesome-detection-transformer). A new toolbox [detrex](https://github.com/IDEA-Research/detrex) is available as well.
|
||||||
|
|
||||||
|
Thanks [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) and [GLIGEN](https://github.com/gligen/GLIGEN) for their awesome models.
|
||||||
|
|
||||||
|
|
||||||
|
## :black_nib: Citation
|
||||||
|
|
||||||
|
If you find our work helpful for your research, please consider citing the following BibTeX entry.
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@article{liu2023grounding,
|
||||||
|
title={Grounding dino: Marrying dino with grounded pre-training for open-set object detection},
|
||||||
|
author={Liu, Shilong and Zeng, Zhaoyang and Ren, Tianhe and Li, Feng and Zhang, Hao and Yang, Jie and Li, Chunyuan and Yang, Jianwei and Su, Hang and Zhu, Jun and others},
|
||||||
|
journal={arXiv preprint arXiv:2303.05499},
|
||||||
|
year={2023}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1 @@
|
|||||||
|
from GroundingDINO import groundingdino
|
@ -0,0 +1,125 @@
|
|||||||
|
import argparse
|
||||||
|
from functools import partial
|
||||||
|
import cv2
|
||||||
|
import requests
|
||||||
|
import os
|
||||||
|
from io import BytesIO
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# prepare the environment
|
||||||
|
os.system("python setup.py build develop --user")
|
||||||
|
os.system("pip install packaging==21.3")
|
||||||
|
os.system("pip install gradio")
|
||||||
|
|
||||||
|
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from groundingdino.models import build_model
|
||||||
|
from groundingdino.util.slconfig import SLConfig
|
||||||
|
from groundingdino.util.utils import clean_state_dict
|
||||||
|
from groundingdino.util.inference import annotate, load_image, predict
|
||||||
|
import groundingdino.datasets.transforms as T
|
||||||
|
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Use this command for evaluate the Grounding DINO model
|
||||||
|
config_file = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
|
||||||
|
ckpt_repo_id = "ShilongLiu/GroundingDINO"
|
||||||
|
ckpt_filenmae = "groundingdino_swint_ogc.pth"
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
|
||||||
|
args = SLConfig.fromfile(model_config_path)
|
||||||
|
model = build_model(args)
|
||||||
|
args.device = device
|
||||||
|
|
||||||
|
cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
|
||||||
|
checkpoint = torch.load(cache_file, map_location='cpu')
|
||||||
|
log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
|
||||||
|
print("Model loaded from {} \n => {}".format(cache_file, log))
|
||||||
|
_ = model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
def image_transform_grounding(init_image):
|
||||||
|
transform = T.Compose([
|
||||||
|
T.RandomResize([800], max_size=1333),
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||||
|
])
|
||||||
|
image, _ = transform(init_image, None) # 3, h, w
|
||||||
|
return init_image, image
|
||||||
|
|
||||||
|
def image_transform_grounding_for_vis(init_image):
|
||||||
|
transform = T.Compose([
|
||||||
|
T.RandomResize([800], max_size=1333),
|
||||||
|
])
|
||||||
|
image, _ = transform(init_image, None) # 3, h, w
|
||||||
|
return image
|
||||||
|
|
||||||
|
model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
|
||||||
|
|
||||||
|
def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
|
||||||
|
init_image = input_image.convert("RGB")
|
||||||
|
original_size = init_image.size
|
||||||
|
|
||||||
|
_, image_tensor = image_transform_grounding(init_image)
|
||||||
|
image_pil: Image = image_transform_grounding_for_vis(init_image)
|
||||||
|
|
||||||
|
# run grounidng
|
||||||
|
boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
|
||||||
|
annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
|
||||||
|
image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
|
||||||
|
|
||||||
|
|
||||||
|
return image_with_box
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
|
||||||
|
parser.add_argument("--debug", action="store_true", help="using debug mode")
|
||||||
|
parser.add_argument("--share", action="store_true", help="share the app")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
block = gr.Blocks().queue()
|
||||||
|
with block:
|
||||||
|
gr.Markdown("# [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO)")
|
||||||
|
gr.Markdown("### Open-World Detection with Grounding DINO")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
input_image = gr.Image(source='upload', type="pil")
|
||||||
|
grounding_caption = gr.Textbox(label="Detection Prompt")
|
||||||
|
run_button = gr.Button(label="Run")
|
||||||
|
with gr.Accordion("Advanced options", open=False):
|
||||||
|
box_threshold = gr.Slider(
|
||||||
|
label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
|
||||||
|
)
|
||||||
|
text_threshold = gr.Slider(
|
||||||
|
label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
gallery = gr.outputs.Image(
|
||||||
|
type="pil",
|
||||||
|
# label="grounding results"
|
||||||
|
).style(full_width=True, full_height=True)
|
||||||
|
# gallery = gr.Gallery(label="Generated images", show_label=False).style(
|
||||||
|
# grid=[1], height="auto", container=True, full_width=True, full_height=True)
|
||||||
|
|
||||||
|
run_button.click(fn=run_grounding, inputs=[
|
||||||
|
input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery])
|
||||||
|
|
||||||
|
|
||||||
|
block.launch(server_name='0.0.0.0', server_port=7579, debug=args.debug, share=args.share)
|
||||||
|
|
@ -0,0 +1,214 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
|
import groundingdino.datasets.transforms as T
|
||||||
|
from groundingdino.models import build_model
|
||||||
|
from groundingdino.util import box_ops
|
||||||
|
from groundingdino.util.slconfig import SLConfig
|
||||||
|
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
|
||||||
|
from groundingdino.util.vl_utils import create_positive_map_from_span
|
||||||
|
|
||||||
|
|
||||||
|
def plot_boxes_to_image(image_pil, tgt):
|
||||||
|
H, W = tgt["size"]
|
||||||
|
boxes = tgt["boxes"]
|
||||||
|
labels = tgt["labels"]
|
||||||
|
assert len(boxes) == len(labels), "boxes and labels must have same length"
|
||||||
|
|
||||||
|
draw = ImageDraw.Draw(image_pil)
|
||||||
|
mask = Image.new("L", image_pil.size, 0)
|
||||||
|
mask_draw = ImageDraw.Draw(mask)
|
||||||
|
|
||||||
|
# draw boxes and masks
|
||||||
|
for box, label in zip(boxes, labels):
|
||||||
|
# from 0..1 to 0..W, 0..H
|
||||||
|
box = box * torch.Tensor([W, H, W, H])
|
||||||
|
# from xywh to xyxy
|
||||||
|
box[:2] -= box[2:] / 2
|
||||||
|
box[2:] += box[:2]
|
||||||
|
# random color
|
||||||
|
color = tuple(np.random.randint(0, 255, size=3).tolist())
|
||||||
|
# draw
|
||||||
|
x0, y0, x1, y1 = box
|
||||||
|
x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
|
||||||
|
|
||||||
|
draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
|
||||||
|
# draw.text((x0, y0), str(label), fill=color)
|
||||||
|
|
||||||
|
font = ImageFont.load_default()
|
||||||
|
if hasattr(font, "getbbox"):
|
||||||
|
bbox = draw.textbbox((x0, y0), str(label), font)
|
||||||
|
else:
|
||||||
|
w, h = draw.textsize(str(label), font)
|
||||||
|
bbox = (x0, y0, w + x0, y0 + h)
|
||||||
|
# bbox = draw.textbbox((x0, y0), str(label))
|
||||||
|
draw.rectangle(bbox, fill=color)
|
||||||
|
draw.text((x0, y0), str(label), fill="white")
|
||||||
|
|
||||||
|
mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)
|
||||||
|
|
||||||
|
return image_pil, mask
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(image_path):
|
||||||
|
# load image
|
||||||
|
image_pil = Image.open(image_path).convert("RGB") # load image
|
||||||
|
|
||||||
|
transform = T.Compose(
|
||||||
|
[
|
||||||
|
T.RandomResize([800], max_size=1333),
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
image, _ = transform(image_pil, None) # 3, h, w
|
||||||
|
return image_pil, image
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
|
||||||
|
args = SLConfig.fromfile(model_config_path)
|
||||||
|
args.device = "cuda" if not cpu_only else "cpu"
|
||||||
|
model = build_model(args)
|
||||||
|
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
|
||||||
|
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
|
||||||
|
print(load_res)
|
||||||
|
_ = model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def get_grounding_output(model, image, caption, box_threshold, text_threshold=None, with_logits=True, cpu_only=False, token_spans=None):
|
||||||
|
assert text_threshold is not None or token_spans is not None, "text_threshould and token_spans should not be None at the same time!"
|
||||||
|
caption = caption.lower()
|
||||||
|
caption = caption.strip()
|
||||||
|
if not caption.endswith("."):
|
||||||
|
caption = caption + "."
|
||||||
|
device = "cuda" if not cpu_only else "cpu"
|
||||||
|
model = model.to(device)
|
||||||
|
image = image.to(device)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(image[None], captions=[caption])
|
||||||
|
logits = outputs["pred_logits"].sigmoid()[0] # (nq, 256)
|
||||||
|
boxes = outputs["pred_boxes"][0] # (nq, 4)
|
||||||
|
|
||||||
|
# filter output
|
||||||
|
if token_spans is None:
|
||||||
|
logits_filt = logits.cpu().clone()
|
||||||
|
boxes_filt = boxes.cpu().clone()
|
||||||
|
filt_mask = logits_filt.max(dim=1)[0] > box_threshold
|
||||||
|
logits_filt = logits_filt[filt_mask] # num_filt, 256
|
||||||
|
boxes_filt = boxes_filt[filt_mask] # num_filt, 4
|
||||||
|
|
||||||
|
# get phrase
|
||||||
|
tokenlizer = model.tokenizer
|
||||||
|
tokenized = tokenlizer(caption)
|
||||||
|
# build pred
|
||||||
|
pred_phrases = []
|
||||||
|
for logit, box in zip(logits_filt, boxes_filt):
|
||||||
|
pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
|
||||||
|
if with_logits:
|
||||||
|
pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
|
||||||
|
else:
|
||||||
|
pred_phrases.append(pred_phrase)
|
||||||
|
else:
|
||||||
|
# given-phrase mode
|
||||||
|
positive_maps = create_positive_map_from_span(
|
||||||
|
model.tokenizer(text_prompt),
|
||||||
|
token_span=token_spans
|
||||||
|
).to(image.device) # n_phrase, 256
|
||||||
|
|
||||||
|
logits_for_phrases = positive_maps @ logits.T # n_phrase, nq
|
||||||
|
all_logits = []
|
||||||
|
all_phrases = []
|
||||||
|
all_boxes = []
|
||||||
|
for (token_span, logit_phr) in zip(token_spans, logits_for_phrases):
|
||||||
|
# get phrase
|
||||||
|
phrase = ' '.join([caption[_s:_e] for (_s, _e) in token_span])
|
||||||
|
# get mask
|
||||||
|
filt_mask = logit_phr > box_threshold
|
||||||
|
# filt box
|
||||||
|
all_boxes.append(boxes[filt_mask])
|
||||||
|
# filt logits
|
||||||
|
all_logits.append(logit_phr[filt_mask])
|
||||||
|
if with_logits:
|
||||||
|
logit_phr_num = logit_phr[filt_mask]
|
||||||
|
all_phrases.extend([phrase + f"({str(logit.item())[:4]})" for logit in logit_phr_num])
|
||||||
|
else:
|
||||||
|
all_phrases.extend([phrase for _ in range(len(filt_mask))])
|
||||||
|
boxes_filt = torch.cat(all_boxes, dim=0).cpu()
|
||||||
|
pred_phrases = all_phrases
|
||||||
|
|
||||||
|
|
||||||
|
return boxes_filt, pred_phrases
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser("Grounding DINO example", add_help=True)
|
||||||
|
parser.add_argument("--config_file", "-c", type=str, required=True, help="path to config file")
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file"
|
||||||
|
)
|
||||||
|
parser.add_argument("--image_path", "-i", type=str, required=True, help="path to image file")
|
||||||
|
parser.add_argument("--text_prompt", "-t", type=str, required=True, help="text prompt")
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
|
||||||
|
parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
|
||||||
|
parser.add_argument("--token_spans", type=str, default=None, help=
|
||||||
|
"The positions of start and end positions of phrases of interest. \
|
||||||
|
For example, a caption is 'a cat and a dog', \
|
||||||
|
if you would like to detect 'cat', the token_spans should be '[[[2, 5]], ]', since 'a cat and a dog'[2:5] is 'cat'. \
|
||||||
|
if you would like to detect 'a cat', the token_spans should be '[[[0, 1], [2, 5]], ]', since 'a cat and a dog'[0:1] is 'a', and 'a cat and a dog'[2:5] is 'cat'. \
|
||||||
|
")
|
||||||
|
|
||||||
|
parser.add_argument("--cpu-only", action="store_true", help="running on cpu only!, default=False")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# cfg
|
||||||
|
config_file = args.config_file # change the path of the model config file
|
||||||
|
checkpoint_path = args.checkpoint_path # change the path of the model
|
||||||
|
image_path = args.image_path
|
||||||
|
text_prompt = args.text_prompt
|
||||||
|
output_dir = args.output_dir
|
||||||
|
box_threshold = args.box_threshold
|
||||||
|
text_threshold = args.text_threshold
|
||||||
|
token_spans = args.token_spans
|
||||||
|
|
||||||
|
# make dir
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
# load image
|
||||||
|
image_pil, image = load_image(image_path)
|
||||||
|
# load model
|
||||||
|
model = load_model(config_file, checkpoint_path, cpu_only=args.cpu_only)
|
||||||
|
|
||||||
|
# visualize raw image
|
||||||
|
image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
|
||||||
|
|
||||||
|
# set the text_threshold to None if token_spans is set.
|
||||||
|
if token_spans is not None:
|
||||||
|
text_threshold = None
|
||||||
|
print("Using token_spans. Set the text_threshold to None.")
|
||||||
|
|
||||||
|
|
||||||
|
# run model
|
||||||
|
boxes_filt, pred_phrases = get_grounding_output(
|
||||||
|
model, image, text_prompt, box_threshold, text_threshold, cpu_only=args.cpu_only, token_spans=eval(token_spans)
|
||||||
|
)
|
||||||
|
|
||||||
|
# visualize pred
|
||||||
|
size = image_pil.size
|
||||||
|
pred_dict = {
|
||||||
|
"boxes": boxes_filt,
|
||||||
|
"size": [size[1], size[0]], # H,W
|
||||||
|
"labels": pred_phrases,
|
||||||
|
}
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
image_with_box = plot_boxes_to_image(image_pil, pred_dict)[0]
|
||||||
|
image_with_box.save(os.path.join(output_dir, "pred.jpg"))
|
@ -0,0 +1,233 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
|
|
||||||
|
from groundingdino.models import build_model
|
||||||
|
import groundingdino.datasets.transforms as T
|
||||||
|
from groundingdino.util import box_ops, get_tokenlizer
|
||||||
|
from groundingdino.util.misc import clean_state_dict, collate_fn
|
||||||
|
from groundingdino.util.slconfig import SLConfig
|
||||||
|
|
||||||
|
# from torchvision.datasets import CocoDetection
|
||||||
|
import torchvision
|
||||||
|
|
||||||
|
from groundingdino.util.vl_utils import build_captions_and_token_span, create_positive_map_from_span
|
||||||
|
from groundingdino.datasets.cocogrounding_eval import CocoGroundingEvaluator
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
|
||||||
|
args = SLConfig.fromfile(model_config_path)
|
||||||
|
args.device = device
|
||||||
|
model = build_model(args)
|
||||||
|
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
|
||||||
|
model.load_state_dict(clean_state_dict(checkpoint["ema_model"]), strict=False)
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
class CocoDetection(torchvision.datasets.CocoDetection):
|
||||||
|
def __init__(self, img_folder, ann_file, transforms):
|
||||||
|
super().__init__(img_folder, ann_file)
|
||||||
|
self._transforms = transforms
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
img, target = super().__getitem__(idx) # target: list
|
||||||
|
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
|
||||||
|
w, h = img.size
|
||||||
|
boxes = [obj["bbox"] for obj in target]
|
||||||
|
boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
|
||||||
|
boxes[:, 2:] += boxes[:, :2] # xywh -> xyxy
|
||||||
|
boxes[:, 0::2].clamp_(min=0, max=w)
|
||||||
|
boxes[:, 1::2].clamp_(min=0, max=h)
|
||||||
|
# filt invalid boxes/masks/keypoints
|
||||||
|
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
|
||||||
|
boxes = boxes[keep]
|
||||||
|
|
||||||
|
target_new = {}
|
||||||
|
image_id = self.ids[idx]
|
||||||
|
target_new["image_id"] = image_id
|
||||||
|
target_new["boxes"] = boxes
|
||||||
|
target_new["orig_size"] = torch.as_tensor([int(h), int(w)])
|
||||||
|
|
||||||
|
if self._transforms is not None:
|
||||||
|
img, target = self._transforms(img, target_new)
|
||||||
|
|
||||||
|
return img, target
|
||||||
|
|
||||||
|
|
||||||
|
class PostProcessCocoGrounding(nn.Module):
|
||||||
|
""" This module converts the model's output into the format expected by the coco api"""
|
||||||
|
|
||||||
|
def __init__(self, num_select=300, coco_api=None, tokenlizer=None) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.num_select = num_select
|
||||||
|
|
||||||
|
assert coco_api is not None
|
||||||
|
category_dict = coco_api.dataset['categories']
|
||||||
|
cat_list = [item['name'] for item in category_dict]
|
||||||
|
captions, cat2tokenspan = build_captions_and_token_span(cat_list, True)
|
||||||
|
tokenspanlist = [cat2tokenspan[cat] for cat in cat_list]
|
||||||
|
positive_map = create_positive_map_from_span(
|
||||||
|
tokenlizer(captions), tokenspanlist) # 80, 256. normed
|
||||||
|
|
||||||
|
id_map = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9, 9: 10, 10: 11, 11: 13, 12: 14, 13: 15, 14: 16, 15: 17, 16: 18, 17: 19, 18: 20, 19: 21, 20: 22, 21: 23, 22: 24, 23: 25, 24: 27, 25: 28, 26: 31, 27: 32, 28: 33, 29: 34, 30: 35, 31: 36, 32: 37, 33: 38, 34: 39, 35: 40, 36: 41, 37: 42, 38: 43, 39: 44, 40: 46,
|
||||||
|
41: 47, 42: 48, 43: 49, 44: 50, 45: 51, 46: 52, 47: 53, 48: 54, 49: 55, 50: 56, 51: 57, 52: 58, 53: 59, 54: 60, 55: 61, 56: 62, 57: 63, 58: 64, 59: 65, 60: 67, 61: 70, 62: 72, 63: 73, 64: 74, 65: 75, 66: 76, 67: 77, 68: 78, 69: 79, 70: 80, 71: 81, 72: 82, 73: 84, 74: 85, 75: 86, 76: 87, 77: 88, 78: 89, 79: 90}
|
||||||
|
|
||||||
|
# build a mapping from label_id to pos_map
|
||||||
|
new_pos_map = torch.zeros((91, 256))
|
||||||
|
for k, v in id_map.items():
|
||||||
|
new_pos_map[v] = positive_map[k]
|
||||||
|
self.positive_map = new_pos_map
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def forward(self, outputs, target_sizes, not_to_xyxy=False):
|
||||||
|
""" Perform the computation
|
||||||
|
Parameters:
|
||||||
|
outputs: raw outputs of the model
|
||||||
|
target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
|
||||||
|
For evaluation, this must be the original image size (before any data augmentation)
|
||||||
|
For visualization, this should be the image size after data augment, but before padding
|
||||||
|
"""
|
||||||
|
num_select = self.num_select
|
||||||
|
out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
|
||||||
|
|
||||||
|
# pos map to logit
|
||||||
|
prob_to_token = out_logits.sigmoid() # bs, 100, 256
|
||||||
|
pos_maps = self.positive_map.to(prob_to_token.device)
|
||||||
|
# (bs, 100, 256) @ (91, 256).T -> (bs, 100, 91)
|
||||||
|
prob_to_label = prob_to_token @ pos_maps.T
|
||||||
|
|
||||||
|
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
|
||||||
|
assert len(out_logits) == len(target_sizes)
|
||||||
|
assert target_sizes.shape[1] == 2
|
||||||
|
|
||||||
|
prob = prob_to_label
|
||||||
|
topk_values, topk_indexes = torch.topk(
|
||||||
|
prob.view(out_logits.shape[0], -1), num_select, dim=1)
|
||||||
|
scores = topk_values
|
||||||
|
topk_boxes = topk_indexes // prob.shape[2]
|
||||||
|
labels = topk_indexes % prob.shape[2]
|
||||||
|
|
||||||
|
if not_to_xyxy:
|
||||||
|
boxes = out_bbox
|
||||||
|
else:
|
||||||
|
boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
|
||||||
|
|
||||||
|
boxes = torch.gather(
|
||||||
|
boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
|
||||||
|
|
||||||
|
# and from relative [0, 1] to absolute [0, height] coordinates
|
||||||
|
img_h, img_w = target_sizes.unbind(1)
|
||||||
|
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
|
||||||
|
boxes = boxes * scale_fct[:, None, :]
|
||||||
|
|
||||||
|
results = [{'scores': s, 'labels': l, 'boxes': b}
|
||||||
|
for s, l, b in zip(scores, labels, boxes)]
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
# config
|
||||||
|
cfg = SLConfig.fromfile(args.config_file)
|
||||||
|
|
||||||
|
# build model
|
||||||
|
model = load_model(args.config_file, args.checkpoint_path)
|
||||||
|
model = model.to(args.device)
|
||||||
|
model = model.eval()
|
||||||
|
|
||||||
|
# build dataloader
|
||||||
|
transform = T.Compose(
|
||||||
|
[
|
||||||
|
T.RandomResize([800], max_size=1333),
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
dataset = CocoDetection(
|
||||||
|
args.image_dir, args.anno_path, transforms=transform)
|
||||||
|
data_loader = DataLoader(
|
||||||
|
dataset, batch_size=1, shuffle=False, num_workers=args.num_workers, collate_fn=collate_fn)
|
||||||
|
|
||||||
|
# build post processor
|
||||||
|
tokenlizer = get_tokenlizer.get_tokenlizer(cfg.text_encoder_type)
|
||||||
|
postprocessor = PostProcessCocoGrounding(
|
||||||
|
coco_api=dataset.coco, tokenlizer=tokenlizer)
|
||||||
|
|
||||||
|
# build evaluator
|
||||||
|
evaluator = CocoGroundingEvaluator(
|
||||||
|
dataset.coco, iou_types=("bbox",), useCats=True)
|
||||||
|
|
||||||
|
# build captions
|
||||||
|
category_dict = dataset.coco.dataset['categories']
|
||||||
|
cat_list = [item['name'] for item in category_dict]
|
||||||
|
caption = " . ".join(cat_list) + ' .'
|
||||||
|
print("Input text prompt:", caption)
|
||||||
|
|
||||||
|
# run inference
|
||||||
|
start = time.time()
|
||||||
|
for i, (images, targets) in enumerate(data_loader):
|
||||||
|
# get images and captions
|
||||||
|
images = images.tensors.to(args.device)
|
||||||
|
bs = images.shape[0]
|
||||||
|
input_captions = [caption] * bs
|
||||||
|
|
||||||
|
# feed to the model
|
||||||
|
outputs = model(images, captions=input_captions)
|
||||||
|
|
||||||
|
orig_target_sizes = torch.stack(
|
||||||
|
[t["orig_size"] for t in targets], dim=0).to(images.device)
|
||||||
|
results = postprocessor(outputs, orig_target_sizes)
|
||||||
|
cocogrounding_res = {
|
||||||
|
target["image_id"]: output for target, output in zip(targets, results)}
|
||||||
|
evaluator.update(cocogrounding_res)
|
||||||
|
|
||||||
|
if (i+1) % 30 == 0:
|
||||||
|
used_time = time.time() - start
|
||||||
|
eta = len(data_loader) / (i+1e-5) * used_time - used_time
|
||||||
|
print(
|
||||||
|
f"processed {i}/{len(data_loader)} images. time: {used_time:.2f}s, ETA: {eta:.2f}s")
|
||||||
|
|
||||||
|
evaluator.synchronize_between_processes()
|
||||||
|
evaluator.accumulate()
|
||||||
|
evaluator.summarize()
|
||||||
|
|
||||||
|
print("Final results:", evaluator.coco_eval["bbox"].stats.tolist())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
"Grounding DINO eval on COCO", add_help=True)
|
||||||
|
# load model
|
||||||
|
parser.add_argument("--config_file", "-c", type=str,
|
||||||
|
required=True, help="path to config file")
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file"
|
||||||
|
)
|
||||||
|
parser.add_argument("--device", type=str, default="cuda",
|
||||||
|
help="running device (default: cuda)")
|
||||||
|
|
||||||
|
# post processing
|
||||||
|
parser.add_argument("--num_select", type=int, default=300,
|
||||||
|
help="number of topk to select")
|
||||||
|
|
||||||
|
# coco info
|
||||||
|
parser.add_argument("--anno_path", type=str,
|
||||||
|
required=True, help="coco root")
|
||||||
|
parser.add_argument("--image_dir", type=str,
|
||||||
|
required=True, help="coco image dir")
|
||||||
|
parser.add_argument("--num_workers", type=int, default=4,
|
||||||
|
help="number of workers for dataloader")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args)
|
@ -0,0 +1,43 @@
|
|||||||
|
batch_size = 1
|
||||||
|
modelname = "groundingdino"
|
||||||
|
backbone = "swin_B_384_22k"
|
||||||
|
position_embedding = "sine"
|
||||||
|
pe_temperatureH = 20
|
||||||
|
pe_temperatureW = 20
|
||||||
|
return_interm_indices = [1, 2, 3]
|
||||||
|
backbone_freeze_keywords = None
|
||||||
|
enc_layers = 6
|
||||||
|
dec_layers = 6
|
||||||
|
pre_norm = False
|
||||||
|
dim_feedforward = 2048
|
||||||
|
hidden_dim = 256
|
||||||
|
dropout = 0.0
|
||||||
|
nheads = 8
|
||||||
|
num_queries = 900
|
||||||
|
query_dim = 4
|
||||||
|
num_patterns = 0
|
||||||
|
num_feature_levels = 4
|
||||||
|
enc_n_points = 4
|
||||||
|
dec_n_points = 4
|
||||||
|
two_stage_type = "standard"
|
||||||
|
two_stage_bbox_embed_share = False
|
||||||
|
two_stage_class_embed_share = False
|
||||||
|
transformer_activation = "relu"
|
||||||
|
dec_pred_bbox_embed_share = True
|
||||||
|
dn_box_noise_scale = 1.0
|
||||||
|
dn_label_noise_ratio = 0.5
|
||||||
|
dn_label_coef = 1.0
|
||||||
|
dn_bbox_coef = 1.0
|
||||||
|
embed_init_tgt = True
|
||||||
|
dn_labelbook_size = 2000
|
||||||
|
max_text_len = 256
|
||||||
|
text_encoder_type = "bert-base-uncased"
|
||||||
|
use_text_enhancer = True
|
||||||
|
use_fusion_layer = True
|
||||||
|
use_checkpoint = True
|
||||||
|
use_transformer_ckpt = True
|
||||||
|
use_text_cross_attention = True
|
||||||
|
text_dropout = 0.0
|
||||||
|
fusion_dropout = 0.0
|
||||||
|
fusion_droppath = 0.1
|
||||||
|
sub_sentence_present = True
|
@ -0,0 +1,43 @@
|
|||||||
|
batch_size = 1
|
||||||
|
modelname = "groundingdino"
|
||||||
|
backbone = "swin_T_224_1k"
|
||||||
|
position_embedding = "sine"
|
||||||
|
pe_temperatureH = 20
|
||||||
|
pe_temperatureW = 20
|
||||||
|
return_interm_indices = [1, 2, 3]
|
||||||
|
backbone_freeze_keywords = None
|
||||||
|
enc_layers = 6
|
||||||
|
dec_layers = 6
|
||||||
|
pre_norm = False
|
||||||
|
dim_feedforward = 2048
|
||||||
|
hidden_dim = 256
|
||||||
|
dropout = 0.0
|
||||||
|
nheads = 8
|
||||||
|
num_queries = 900
|
||||||
|
query_dim = 4
|
||||||
|
num_patterns = 0
|
||||||
|
num_feature_levels = 4
|
||||||
|
enc_n_points = 4
|
||||||
|
dec_n_points = 4
|
||||||
|
two_stage_type = "standard"
|
||||||
|
two_stage_bbox_embed_share = False
|
||||||
|
two_stage_class_embed_share = False
|
||||||
|
transformer_activation = "relu"
|
||||||
|
dec_pred_bbox_embed_share = True
|
||||||
|
dn_box_noise_scale = 1.0
|
||||||
|
dn_label_noise_ratio = 0.5
|
||||||
|
dn_label_coef = 1.0
|
||||||
|
dn_bbox_coef = 1.0
|
||||||
|
embed_init_tgt = True
|
||||||
|
dn_labelbook_size = 2000
|
||||||
|
max_text_len = 256
|
||||||
|
text_encoder_type = "bert-base-uncased"
|
||||||
|
use_text_enhancer = True
|
||||||
|
use_fusion_layer = True
|
||||||
|
use_checkpoint = True
|
||||||
|
use_transformer_ckpt = True
|
||||||
|
use_text_cross_attention = True
|
||||||
|
text_dropout = 0.0
|
||||||
|
fusion_dropout = 0.0
|
||||||
|
fusion_droppath = 0.1
|
||||||
|
sub_sentence_present = True
|
@ -0,0 +1,269 @@
|
|||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Grounding DINO. Midified by Shilong Liu.
|
||||||
|
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||||
|
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
"""
|
||||||
|
COCO evaluator that works in distributed mode.
|
||||||
|
|
||||||
|
Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
|
||||||
|
The difference is that there is less copy-pasting from pycocotools
|
||||||
|
in the end of the file, as python3 can suppress prints with contextlib
|
||||||
|
"""
|
||||||
|
import contextlib
|
||||||
|
import copy
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pycocotools.mask as mask_util
|
||||||
|
import torch
|
||||||
|
from pycocotools.coco import COCO
|
||||||
|
from pycocotools.cocoeval import COCOeval
|
||||||
|
|
||||||
|
from groundingdino.util.misc import all_gather
|
||||||
|
|
||||||
|
|
||||||
|
class CocoGroundingEvaluator(object):
|
||||||
|
def __init__(self, coco_gt, iou_types, useCats=True):
|
||||||
|
assert isinstance(iou_types, (list, tuple))
|
||||||
|
coco_gt = copy.deepcopy(coco_gt)
|
||||||
|
self.coco_gt = coco_gt
|
||||||
|
|
||||||
|
self.iou_types = iou_types
|
||||||
|
self.coco_eval = {}
|
||||||
|
for iou_type in iou_types:
|
||||||
|
self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
|
||||||
|
self.coco_eval[iou_type].useCats = useCats
|
||||||
|
|
||||||
|
self.img_ids = []
|
||||||
|
self.eval_imgs = {k: [] for k in iou_types}
|
||||||
|
self.useCats = useCats
|
||||||
|
|
||||||
|
def update(self, predictions):
|
||||||
|
img_ids = list(np.unique(list(predictions.keys())))
|
||||||
|
self.img_ids.extend(img_ids)
|
||||||
|
|
||||||
|
for iou_type in self.iou_types:
|
||||||
|
results = self.prepare(predictions, iou_type)
|
||||||
|
|
||||||
|
# suppress pycocotools prints
|
||||||
|
with open(os.devnull, "w") as devnull:
|
||||||
|
with contextlib.redirect_stdout(devnull):
|
||||||
|
coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
|
||||||
|
|
||||||
|
coco_eval = self.coco_eval[iou_type]
|
||||||
|
|
||||||
|
coco_eval.cocoDt = coco_dt
|
||||||
|
coco_eval.params.imgIds = list(img_ids)
|
||||||
|
coco_eval.params.useCats = self.useCats
|
||||||
|
img_ids, eval_imgs = evaluate(coco_eval)
|
||||||
|
|
||||||
|
self.eval_imgs[iou_type].append(eval_imgs)
|
||||||
|
|
||||||
|
def synchronize_between_processes(self):
|
||||||
|
for iou_type in self.iou_types:
|
||||||
|
self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
|
||||||
|
create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
|
||||||
|
|
||||||
|
def accumulate(self):
|
||||||
|
for coco_eval in self.coco_eval.values():
|
||||||
|
coco_eval.accumulate()
|
||||||
|
|
||||||
|
def summarize(self):
|
||||||
|
for iou_type, coco_eval in self.coco_eval.items():
|
||||||
|
print("IoU metric: {}".format(iou_type))
|
||||||
|
coco_eval.summarize()
|
||||||
|
|
||||||
|
def prepare(self, predictions, iou_type):
|
||||||
|
if iou_type == "bbox":
|
||||||
|
return self.prepare_for_coco_detection(predictions)
|
||||||
|
elif iou_type == "segm":
|
||||||
|
return self.prepare_for_coco_segmentation(predictions)
|
||||||
|
elif iou_type == "keypoints":
|
||||||
|
return self.prepare_for_coco_keypoint(predictions)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown iou type {}".format(iou_type))
|
||||||
|
|
||||||
|
def prepare_for_coco_detection(self, predictions):
|
||||||
|
coco_results = []
|
||||||
|
for original_id, prediction in predictions.items():
|
||||||
|
if len(prediction) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
boxes = prediction["boxes"]
|
||||||
|
boxes = convert_to_xywh(boxes).tolist()
|
||||||
|
scores = prediction["scores"].tolist()
|
||||||
|
labels = prediction["labels"].tolist()
|
||||||
|
|
||||||
|
coco_results.extend(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"image_id": original_id,
|
||||||
|
"category_id": labels[k],
|
||||||
|
"bbox": box,
|
||||||
|
"score": scores[k],
|
||||||
|
}
|
||||||
|
for k, box in enumerate(boxes)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return coco_results
|
||||||
|
|
||||||
|
def prepare_for_coco_segmentation(self, predictions):
|
||||||
|
coco_results = []
|
||||||
|
for original_id, prediction in predictions.items():
|
||||||
|
if len(prediction) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
scores = prediction["scores"]
|
||||||
|
labels = prediction["labels"]
|
||||||
|
masks = prediction["masks"]
|
||||||
|
|
||||||
|
masks = masks > 0.5
|
||||||
|
|
||||||
|
scores = prediction["scores"].tolist()
|
||||||
|
labels = prediction["labels"].tolist()
|
||||||
|
|
||||||
|
rles = [
|
||||||
|
mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
|
||||||
|
for mask in masks
|
||||||
|
]
|
||||||
|
for rle in rles:
|
||||||
|
rle["counts"] = rle["counts"].decode("utf-8")
|
||||||
|
|
||||||
|
coco_results.extend(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"image_id": original_id,
|
||||||
|
"category_id": labels[k],
|
||||||
|
"segmentation": rle,
|
||||||
|
"score": scores[k],
|
||||||
|
}
|
||||||
|
for k, rle in enumerate(rles)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return coco_results
|
||||||
|
|
||||||
|
def prepare_for_coco_keypoint(self, predictions):
|
||||||
|
coco_results = []
|
||||||
|
for original_id, prediction in predictions.items():
|
||||||
|
if len(prediction) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
boxes = prediction["boxes"]
|
||||||
|
boxes = convert_to_xywh(boxes).tolist()
|
||||||
|
scores = prediction["scores"].tolist()
|
||||||
|
labels = prediction["labels"].tolist()
|
||||||
|
keypoints = prediction["keypoints"]
|
||||||
|
keypoints = keypoints.flatten(start_dim=1).tolist()
|
||||||
|
|
||||||
|
coco_results.extend(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"image_id": original_id,
|
||||||
|
"category_id": labels[k],
|
||||||
|
"keypoints": keypoint,
|
||||||
|
"score": scores[k],
|
||||||
|
}
|
||||||
|
for k, keypoint in enumerate(keypoints)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return coco_results
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_xywh(boxes):
|
||||||
|
xmin, ymin, xmax, ymax = boxes.unbind(1)
|
||||||
|
return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
def merge(img_ids, eval_imgs):
|
||||||
|
all_img_ids = all_gather(img_ids)
|
||||||
|
all_eval_imgs = all_gather(eval_imgs)
|
||||||
|
|
||||||
|
merged_img_ids = []
|
||||||
|
for p in all_img_ids:
|
||||||
|
merged_img_ids.extend(p)
|
||||||
|
|
||||||
|
merged_eval_imgs = []
|
||||||
|
for p in all_eval_imgs:
|
||||||
|
merged_eval_imgs.append(p)
|
||||||
|
|
||||||
|
merged_img_ids = np.array(merged_img_ids)
|
||||||
|
merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
|
||||||
|
|
||||||
|
# keep only unique (and in sorted order) images
|
||||||
|
merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
|
||||||
|
merged_eval_imgs = merged_eval_imgs[..., idx]
|
||||||
|
|
||||||
|
return merged_img_ids, merged_eval_imgs
|
||||||
|
|
||||||
|
|
||||||
|
def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
|
||||||
|
img_ids, eval_imgs = merge(img_ids, eval_imgs)
|
||||||
|
img_ids = list(img_ids)
|
||||||
|
eval_imgs = list(eval_imgs.flatten())
|
||||||
|
|
||||||
|
coco_eval.evalImgs = eval_imgs
|
||||||
|
coco_eval.params.imgIds = img_ids
|
||||||
|
coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################
|
||||||
|
# From pycocotools, just removed the prints and fixed
|
||||||
|
# a Python3 bug about unicode not defined
|
||||||
|
#################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(self):
|
||||||
|
"""
|
||||||
|
Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
# tic = time.time()
|
||||||
|
# print('Running per image evaluation...')
|
||||||
|
p = self.params
|
||||||
|
# add backward compatibility if useSegm is specified in params
|
||||||
|
if p.useSegm is not None:
|
||||||
|
p.iouType = "segm" if p.useSegm == 1 else "bbox"
|
||||||
|
print("useSegm (deprecated) is not None. Running {} evaluation".format(p.iouType))
|
||||||
|
# print('Evaluate annotation type *{}*'.format(p.iouType))
|
||||||
|
p.imgIds = list(np.unique(p.imgIds))
|
||||||
|
if p.useCats:
|
||||||
|
p.catIds = list(np.unique(p.catIds))
|
||||||
|
p.maxDets = sorted(p.maxDets)
|
||||||
|
self.params = p
|
||||||
|
|
||||||
|
self._prepare()
|
||||||
|
# loop through images, area range, max detection number
|
||||||
|
catIds = p.catIds if p.useCats else [-1]
|
||||||
|
|
||||||
|
if p.iouType == "segm" or p.iouType == "bbox":
|
||||||
|
computeIoU = self.computeIoU
|
||||||
|
elif p.iouType == "keypoints":
|
||||||
|
computeIoU = self.computeOks
|
||||||
|
self.ious = {
|
||||||
|
(imgId, catId): computeIoU(imgId, catId)
|
||||||
|
for imgId in p.imgIds
|
||||||
|
for catId in catIds}
|
||||||
|
|
||||||
|
evaluateImg = self.evaluateImg
|
||||||
|
maxDet = p.maxDets[-1]
|
||||||
|
evalImgs = [
|
||||||
|
evaluateImg(imgId, catId, areaRng, maxDet)
|
||||||
|
for catId in catIds
|
||||||
|
for areaRng in p.areaRng
|
||||||
|
for imgId in p.imgIds
|
||||||
|
]
|
||||||
|
# this is NOT in the pycocotools code, but could be done outside
|
||||||
|
evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
|
||||||
|
self._paramsEval = copy.deepcopy(self.params)
|
||||||
|
# toc = time.time()
|
||||||
|
# print('DONE (t={:0.2f}s).'.format(toc-tic))
|
||||||
|
return p.imgIds, evalImgs
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################
|
||||||
|
# end of straight copy from pycocotools, just removing the prints
|
||||||
|
#################################################################
|
@ -0,0 +1,311 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
"""
|
||||||
|
Transforms and data augmentation for both image + bbox.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
import PIL
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms as T
|
||||||
|
import torchvision.transforms.functional as F
|
||||||
|
|
||||||
|
from groundingdino.util.box_ops import box_xyxy_to_cxcywh
|
||||||
|
from groundingdino.util.misc import interpolate
|
||||||
|
|
||||||
|
|
||||||
|
def crop(image, target, region):
|
||||||
|
cropped_image = F.crop(image, *region)
|
||||||
|
|
||||||
|
target = target.copy()
|
||||||
|
i, j, h, w = region
|
||||||
|
|
||||||
|
# should we do something wrt the original size?
|
||||||
|
target["size"] = torch.tensor([h, w])
|
||||||
|
|
||||||
|
fields = ["labels", "area", "iscrowd", "positive_map"]
|
||||||
|
|
||||||
|
if "boxes" in target:
|
||||||
|
boxes = target["boxes"]
|
||||||
|
max_size = torch.as_tensor([w, h], dtype=torch.float32)
|
||||||
|
cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
|
||||||
|
cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
|
||||||
|
cropped_boxes = cropped_boxes.clamp(min=0)
|
||||||
|
area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
|
||||||
|
target["boxes"] = cropped_boxes.reshape(-1, 4)
|
||||||
|
target["area"] = area
|
||||||
|
fields.append("boxes")
|
||||||
|
|
||||||
|
if "masks" in target:
|
||||||
|
# FIXME should we update the area here if there are no boxes?
|
||||||
|
target["masks"] = target["masks"][:, i : i + h, j : j + w]
|
||||||
|
fields.append("masks")
|
||||||
|
|
||||||
|
# remove elements for which the boxes or masks that have zero area
|
||||||
|
if "boxes" in target or "masks" in target:
|
||||||
|
# favor boxes selection when defining which elements to keep
|
||||||
|
# this is compatible with previous implementation
|
||||||
|
if "boxes" in target:
|
||||||
|
cropped_boxes = target["boxes"].reshape(-1, 2, 2)
|
||||||
|
keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
|
||||||
|
else:
|
||||||
|
keep = target["masks"].flatten(1).any(1)
|
||||||
|
|
||||||
|
for field in fields:
|
||||||
|
if field in target:
|
||||||
|
target[field] = target[field][keep]
|
||||||
|
|
||||||
|
if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO":
|
||||||
|
# for debug and visualization only.
|
||||||
|
if "strings_positive" in target:
|
||||||
|
target["strings_positive"] = [
|
||||||
|
_i for _i, _j in zip(target["strings_positive"], keep) if _j
|
||||||
|
]
|
||||||
|
|
||||||
|
return cropped_image, target
|
||||||
|
|
||||||
|
|
||||||
|
def hflip(image, target):
|
||||||
|
flipped_image = F.hflip(image)
|
||||||
|
|
||||||
|
w, h = image.size
|
||||||
|
|
||||||
|
target = target.copy()
|
||||||
|
if "boxes" in target:
|
||||||
|
boxes = target["boxes"]
|
||||||
|
boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor(
|
||||||
|
[w, 0, w, 0]
|
||||||
|
)
|
||||||
|
target["boxes"] = boxes
|
||||||
|
|
||||||
|
if "masks" in target:
|
||||||
|
target["masks"] = target["masks"].flip(-1)
|
||||||
|
|
||||||
|
return flipped_image, target
|
||||||
|
|
||||||
|
|
||||||
|
def resize(image, target, size, max_size=None):
|
||||||
|
# size can be min_size (scalar) or (w, h) tuple
|
||||||
|
|
||||||
|
def get_size_with_aspect_ratio(image_size, size, max_size=None):
|
||||||
|
w, h = image_size
|
||||||
|
if max_size is not None:
|
||||||
|
min_original_size = float(min((w, h)))
|
||||||
|
max_original_size = float(max((w, h)))
|
||||||
|
if max_original_size / min_original_size * size > max_size:
|
||||||
|
size = int(round(max_size * min_original_size / max_original_size))
|
||||||
|
|
||||||
|
if (w <= h and w == size) or (h <= w and h == size):
|
||||||
|
return (h, w)
|
||||||
|
|
||||||
|
if w < h:
|
||||||
|
ow = size
|
||||||
|
oh = int(size * h / w)
|
||||||
|
else:
|
||||||
|
oh = size
|
||||||
|
ow = int(size * w / h)
|
||||||
|
|
||||||
|
return (oh, ow)
|
||||||
|
|
||||||
|
def get_size(image_size, size, max_size=None):
|
||||||
|
if isinstance(size, (list, tuple)):
|
||||||
|
return size[::-1]
|
||||||
|
else:
|
||||||
|
return get_size_with_aspect_ratio(image_size, size, max_size)
|
||||||
|
|
||||||
|
size = get_size(image.size, size, max_size)
|
||||||
|
rescaled_image = F.resize(image, size)
|
||||||
|
|
||||||
|
if target is None:
|
||||||
|
return rescaled_image, None
|
||||||
|
|
||||||
|
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
|
||||||
|
ratio_width, ratio_height = ratios
|
||||||
|
|
||||||
|
target = target.copy()
|
||||||
|
if "boxes" in target:
|
||||||
|
boxes = target["boxes"]
|
||||||
|
scaled_boxes = boxes * torch.as_tensor(
|
||||||
|
[ratio_width, ratio_height, ratio_width, ratio_height]
|
||||||
|
)
|
||||||
|
target["boxes"] = scaled_boxes
|
||||||
|
|
||||||
|
if "area" in target:
|
||||||
|
area = target["area"]
|
||||||
|
scaled_area = area * (ratio_width * ratio_height)
|
||||||
|
target["area"] = scaled_area
|
||||||
|
|
||||||
|
h, w = size
|
||||||
|
target["size"] = torch.tensor([h, w])
|
||||||
|
|
||||||
|
if "masks" in target:
|
||||||
|
target["masks"] = (
|
||||||
|
interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
return rescaled_image, target
|
||||||
|
|
||||||
|
|
||||||
|
def pad(image, target, padding):
|
||||||
|
# assumes that we only pad on the bottom right corners
|
||||||
|
padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
|
||||||
|
if target is None:
|
||||||
|
return padded_image, None
|
||||||
|
target = target.copy()
|
||||||
|
# should we do something wrt the original size?
|
||||||
|
target["size"] = torch.tensor(padded_image.size[::-1])
|
||||||
|
if "masks" in target:
|
||||||
|
target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1]))
|
||||||
|
return padded_image, target
|
||||||
|
|
||||||
|
|
||||||
|
class ResizeDebug(object):
|
||||||
|
def __init__(self, size):
|
||||||
|
self.size = size
|
||||||
|
|
||||||
|
def __call__(self, img, target):
|
||||||
|
return resize(img, target, self.size)
|
||||||
|
|
||||||
|
|
||||||
|
class RandomCrop(object):
|
||||||
|
def __init__(self, size):
|
||||||
|
self.size = size
|
||||||
|
|
||||||
|
def __call__(self, img, target):
|
||||||
|
region = T.RandomCrop.get_params(img, self.size)
|
||||||
|
return crop(img, target, region)
|
||||||
|
|
||||||
|
|
||||||
|
class RandomSizeCrop(object):
|
||||||
|
def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False):
|
||||||
|
# respect_boxes: True to keep all boxes
|
||||||
|
# False to tolerence box filter
|
||||||
|
self.min_size = min_size
|
||||||
|
self.max_size = max_size
|
||||||
|
self.respect_boxes = respect_boxes
|
||||||
|
|
||||||
|
def __call__(self, img: PIL.Image.Image, target: dict):
|
||||||
|
init_boxes = len(target["boxes"])
|
||||||
|
max_patience = 10
|
||||||
|
for i in range(max_patience):
|
||||||
|
w = random.randint(self.min_size, min(img.width, self.max_size))
|
||||||
|
h = random.randint(self.min_size, min(img.height, self.max_size))
|
||||||
|
region = T.RandomCrop.get_params(img, [h, w])
|
||||||
|
result_img, result_target = crop(img, target, region)
|
||||||
|
if (
|
||||||
|
not self.respect_boxes
|
||||||
|
or len(result_target["boxes"]) == init_boxes
|
||||||
|
or i == max_patience - 1
|
||||||
|
):
|
||||||
|
return result_img, result_target
|
||||||
|
return result_img, result_target
|
||||||
|
|
||||||
|
|
||||||
|
class CenterCrop(object):
|
||||||
|
def __init__(self, size):
|
||||||
|
self.size = size
|
||||||
|
|
||||||
|
def __call__(self, img, target):
|
||||||
|
image_width, image_height = img.size
|
||||||
|
crop_height, crop_width = self.size
|
||||||
|
crop_top = int(round((image_height - crop_height) / 2.0))
|
||||||
|
crop_left = int(round((image_width - crop_width) / 2.0))
|
||||||
|
return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
|
||||||
|
|
||||||
|
|
||||||
|
class RandomHorizontalFlip(object):
|
||||||
|
def __init__(self, p=0.5):
|
||||||
|
self.p = p
|
||||||
|
|
||||||
|
def __call__(self, img, target):
|
||||||
|
if random.random() < self.p:
|
||||||
|
return hflip(img, target)
|
||||||
|
return img, target
|
||||||
|
|
||||||
|
|
||||||
|
class RandomResize(object):
|
||||||
|
def __init__(self, sizes, max_size=None):
|
||||||
|
assert isinstance(sizes, (list, tuple))
|
||||||
|
self.sizes = sizes
|
||||||
|
self.max_size = max_size
|
||||||
|
|
||||||
|
def __call__(self, img, target=None):
|
||||||
|
size = random.choice(self.sizes)
|
||||||
|
return resize(img, target, size, self.max_size)
|
||||||
|
|
||||||
|
|
||||||
|
class RandomPad(object):
|
||||||
|
def __init__(self, max_pad):
|
||||||
|
self.max_pad = max_pad
|
||||||
|
|
||||||
|
def __call__(self, img, target):
|
||||||
|
pad_x = random.randint(0, self.max_pad)
|
||||||
|
pad_y = random.randint(0, self.max_pad)
|
||||||
|
return pad(img, target, (pad_x, pad_y))
|
||||||
|
|
||||||
|
|
||||||
|
class RandomSelect(object):
|
||||||
|
"""
|
||||||
|
Randomly selects between transforms1 and transforms2,
|
||||||
|
with probability p for transforms1 and (1 - p) for transforms2
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, transforms1, transforms2, p=0.5):
|
||||||
|
self.transforms1 = transforms1
|
||||||
|
self.transforms2 = transforms2
|
||||||
|
self.p = p
|
||||||
|
|
||||||
|
def __call__(self, img, target):
|
||||||
|
if random.random() < self.p:
|
||||||
|
return self.transforms1(img, target)
|
||||||
|
return self.transforms2(img, target)
|
||||||
|
|
||||||
|
|
||||||
|
class ToTensor(object):
|
||||||
|
def __call__(self, img, target):
|
||||||
|
return F.to_tensor(img), target
|
||||||
|
|
||||||
|
|
||||||
|
class RandomErasing(object):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self.eraser = T.RandomErasing(*args, **kwargs)
|
||||||
|
|
||||||
|
def __call__(self, img, target):
|
||||||
|
return self.eraser(img), target
|
||||||
|
|
||||||
|
|
||||||
|
class Normalize(object):
|
||||||
|
def __init__(self, mean, std):
|
||||||
|
self.mean = mean
|
||||||
|
self.std = std
|
||||||
|
|
||||||
|
def __call__(self, image, target=None):
|
||||||
|
image = F.normalize(image, mean=self.mean, std=self.std)
|
||||||
|
if target is None:
|
||||||
|
return image, None
|
||||||
|
target = target.copy()
|
||||||
|
h, w = image.shape[-2:]
|
||||||
|
if "boxes" in target:
|
||||||
|
boxes = target["boxes"]
|
||||||
|
boxes = box_xyxy_to_cxcywh(boxes)
|
||||||
|
boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
|
||||||
|
target["boxes"] = boxes
|
||||||
|
return image, target
|
||||||
|
|
||||||
|
|
||||||
|
class Compose(object):
|
||||||
|
def __init__(self, transforms):
|
||||||
|
self.transforms = transforms
|
||||||
|
|
||||||
|
def __call__(self, image, target):
|
||||||
|
for t in self.transforms:
|
||||||
|
image, target = t(image, target)
|
||||||
|
return image, target
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
format_string = self.__class__.__name__ + "("
|
||||||
|
for t in self.transforms:
|
||||||
|
format_string += "\n"
|
||||||
|
format_string += " {0}".format(t)
|
||||||
|
format_string += "\n)"
|
||||||
|
return format_string
|
@ -0,0 +1,15 @@
|
|||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Grounding DINO
|
||||||
|
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||||
|
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Conditional DETR
|
||||||
|
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Copied from DETR (https://github.com/facebookresearch/detr)
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
|
||||||
|
from .groundingdino import build_groundingdino
|
@ -0,0 +1 @@
|
|||||||
|
from .backbone import build_backbone
|
@ -0,0 +1,221 @@
|
|||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Grounding DINO
|
||||||
|
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||||
|
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Conditional DETR
|
||||||
|
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Copied from DETR (https://github.com/facebookresearch/detr)
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
|
||||||
|
"""
|
||||||
|
Backbone modules.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchvision
|
||||||
|
from torch import nn
|
||||||
|
from torchvision.models._utils import IntermediateLayerGetter
|
||||||
|
|
||||||
|
from groundingdino.util.misc import NestedTensor, clean_state_dict, is_main_process
|
||||||
|
|
||||||
|
from .position_encoding import build_position_encoding
|
||||||
|
from .swin_transformer import build_swin_transformer
|
||||||
|
|
||||||
|
|
||||||
|
class FrozenBatchNorm2d(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
||||||
|
|
||||||
|
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
|
||||||
|
without which any other models than torchvision.models.resnet[18,34,50,101]
|
||||||
|
produce nans.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, n):
|
||||||
|
super(FrozenBatchNorm2d, self).__init__()
|
||||||
|
self.register_buffer("weight", torch.ones(n))
|
||||||
|
self.register_buffer("bias", torch.zeros(n))
|
||||||
|
self.register_buffer("running_mean", torch.zeros(n))
|
||||||
|
self.register_buffer("running_var", torch.ones(n))
|
||||||
|
|
||||||
|
def _load_from_state_dict(
|
||||||
|
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||||
|
):
|
||||||
|
num_batches_tracked_key = prefix + "num_batches_tracked"
|
||||||
|
if num_batches_tracked_key in state_dict:
|
||||||
|
del state_dict[num_batches_tracked_key]
|
||||||
|
|
||||||
|
super(FrozenBatchNorm2d, self)._load_from_state_dict(
|
||||||
|
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# move reshapes to the beginning
|
||||||
|
# to make it fuser-friendly
|
||||||
|
w = self.weight.reshape(1, -1, 1, 1)
|
||||||
|
b = self.bias.reshape(1, -1, 1, 1)
|
||||||
|
rv = self.running_var.reshape(1, -1, 1, 1)
|
||||||
|
rm = self.running_mean.reshape(1, -1, 1, 1)
|
||||||
|
eps = 1e-5
|
||||||
|
scale = w * (rv + eps).rsqrt()
|
||||||
|
bias = b - rm * scale
|
||||||
|
return x * scale + bias
|
||||||
|
|
||||||
|
|
||||||
|
class BackboneBase(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
backbone: nn.Module,
|
||||||
|
train_backbone: bool,
|
||||||
|
num_channels: int,
|
||||||
|
return_interm_indices: list,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
for name, parameter in backbone.named_parameters():
|
||||||
|
if (
|
||||||
|
not train_backbone
|
||||||
|
or "layer2" not in name
|
||||||
|
and "layer3" not in name
|
||||||
|
and "layer4" not in name
|
||||||
|
):
|
||||||
|
parameter.requires_grad_(False)
|
||||||
|
|
||||||
|
return_layers = {}
|
||||||
|
for idx, layer_index in enumerate(return_interm_indices):
|
||||||
|
return_layers.update(
|
||||||
|
{"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)}
|
||||||
|
)
|
||||||
|
|
||||||
|
# if len:
|
||||||
|
# if use_stage1_feature:
|
||||||
|
# return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
|
||||||
|
# else:
|
||||||
|
# return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
|
||||||
|
# else:
|
||||||
|
# return_layers = {'layer4': "0"}
|
||||||
|
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
||||||
|
self.num_channels = num_channels
|
||||||
|
|
||||||
|
def forward(self, tensor_list: NestedTensor):
|
||||||
|
xs = self.body(tensor_list.tensors)
|
||||||
|
out: Dict[str, NestedTensor] = {}
|
||||||
|
for name, x in xs.items():
|
||||||
|
m = tensor_list.mask
|
||||||
|
assert m is not None
|
||||||
|
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
||||||
|
out[name] = NestedTensor(x, mask)
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Backbone(BackboneBase):
|
||||||
|
"""ResNet backbone with frozen BatchNorm."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
train_backbone: bool,
|
||||||
|
dilation: bool,
|
||||||
|
return_interm_indices: list,
|
||||||
|
batch_norm=FrozenBatchNorm2d,
|
||||||
|
):
|
||||||
|
if name in ["resnet18", "resnet34", "resnet50", "resnet101"]:
|
||||||
|
backbone = getattr(torchvision.models, name)(
|
||||||
|
replace_stride_with_dilation=[False, False, dilation],
|
||||||
|
pretrained=is_main_process(),
|
||||||
|
norm_layer=batch_norm,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Why you can get here with name {}".format(name))
|
||||||
|
# num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
|
||||||
|
assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available."
|
||||||
|
assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
|
||||||
|
num_channels_all = [256, 512, 1024, 2048]
|
||||||
|
num_channels = num_channels_all[4 - len(return_interm_indices) :]
|
||||||
|
super().__init__(backbone, train_backbone, num_channels, return_interm_indices)
|
||||||
|
|
||||||
|
|
||||||
|
class Joiner(nn.Sequential):
|
||||||
|
def __init__(self, backbone, position_embedding):
|
||||||
|
super().__init__(backbone, position_embedding)
|
||||||
|
|
||||||
|
def forward(self, tensor_list: NestedTensor):
|
||||||
|
xs = self[0](tensor_list)
|
||||||
|
out: List[NestedTensor] = []
|
||||||
|
pos = []
|
||||||
|
for name, x in xs.items():
|
||||||
|
out.append(x)
|
||||||
|
# position encoding
|
||||||
|
pos.append(self[1](x).to(x.tensors.dtype))
|
||||||
|
|
||||||
|
return out, pos
|
||||||
|
|
||||||
|
|
||||||
|
def build_backbone(args):
|
||||||
|
"""
|
||||||
|
Useful args:
|
||||||
|
- backbone: backbone name
|
||||||
|
- lr_backbone:
|
||||||
|
- dilation
|
||||||
|
- return_interm_indices: available: [0,1,2,3], [1,2,3], [3]
|
||||||
|
- backbone_freeze_keywords:
|
||||||
|
- use_checkpoint: for swin only for now
|
||||||
|
|
||||||
|
"""
|
||||||
|
position_embedding = build_position_encoding(args)
|
||||||
|
train_backbone = True
|
||||||
|
if not train_backbone:
|
||||||
|
raise ValueError("Please set lr_backbone > 0")
|
||||||
|
return_interm_indices = args.return_interm_indices
|
||||||
|
assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
|
||||||
|
args.backbone_freeze_keywords
|
||||||
|
use_checkpoint = getattr(args, "use_checkpoint", False)
|
||||||
|
|
||||||
|
if args.backbone in ["resnet50", "resnet101"]:
|
||||||
|
backbone = Backbone(
|
||||||
|
args.backbone,
|
||||||
|
train_backbone,
|
||||||
|
args.dilation,
|
||||||
|
return_interm_indices,
|
||||||
|
batch_norm=FrozenBatchNorm2d,
|
||||||
|
)
|
||||||
|
bb_num_channels = backbone.num_channels
|
||||||
|
elif args.backbone in [
|
||||||
|
"swin_T_224_1k",
|
||||||
|
"swin_B_224_22k",
|
||||||
|
"swin_B_384_22k",
|
||||||
|
"swin_L_224_22k",
|
||||||
|
"swin_L_384_22k",
|
||||||
|
]:
|
||||||
|
pretrain_img_size = int(args.backbone.split("_")[-2])
|
||||||
|
backbone = build_swin_transformer(
|
||||||
|
args.backbone,
|
||||||
|
pretrain_img_size=pretrain_img_size,
|
||||||
|
out_indices=tuple(return_interm_indices),
|
||||||
|
dilation=False,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Unknown backbone {}".format(args.backbone))
|
||||||
|
|
||||||
|
assert len(bb_num_channels) == len(
|
||||||
|
return_interm_indices
|
||||||
|
), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"
|
||||||
|
|
||||||
|
model = Joiner(backbone, position_embedding)
|
||||||
|
model.num_channels = bb_num_channels
|
||||||
|
assert isinstance(
|
||||||
|
bb_num_channels, List
|
||||||
|
), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels))
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
return model
|
@ -0,0 +1,186 @@
|
|||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Grounding DINO
|
||||||
|
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||||
|
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# DINO
|
||||||
|
# Copyright (c) 2022 IDEA. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Conditional DETR
|
||||||
|
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Copied from DETR (https://github.com/facebookresearch/detr)
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
|
||||||
|
"""
|
||||||
|
Various positional encodings for the transformer.
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from groundingdino.util.misc import NestedTensor
|
||||||
|
|
||||||
|
|
||||||
|
class PositionEmbeddingSine(nn.Module):
|
||||||
|
"""
|
||||||
|
This is a more standard version of the position embedding, very similar to the one
|
||||||
|
used by the Attention is all you need paper, generalized to work on images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
||||||
|
super().__init__()
|
||||||
|
self.num_pos_feats = num_pos_feats
|
||||||
|
self.temperature = temperature
|
||||||
|
self.normalize = normalize
|
||||||
|
if scale is not None and normalize is False:
|
||||||
|
raise ValueError("normalize should be True if scale is passed")
|
||||||
|
if scale is None:
|
||||||
|
scale = 2 * math.pi
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def forward(self, tensor_list: NestedTensor):
|
||||||
|
x = tensor_list.tensors
|
||||||
|
mask = tensor_list.mask
|
||||||
|
assert mask is not None
|
||||||
|
not_mask = ~mask
|
||||||
|
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
||||||
|
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
||||||
|
if self.normalize:
|
||||||
|
eps = 1e-6
|
||||||
|
# if os.environ.get("SHILONG_AMP", None) == '1':
|
||||||
|
# eps = 1e-4
|
||||||
|
# else:
|
||||||
|
# eps = 1e-6
|
||||||
|
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||||
|
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||||
|
|
||||||
|
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||||
|
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||||
|
|
||||||
|
pos_x = x_embed[:, :, :, None] / dim_t
|
||||||
|
pos_y = y_embed[:, :, :, None] / dim_t
|
||||||
|
pos_x = torch.stack(
|
||||||
|
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
||||||
|
).flatten(3)
|
||||||
|
pos_y = torch.stack(
|
||||||
|
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
||||||
|
).flatten(3)
|
||||||
|
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||||
|
return pos
|
||||||
|
|
||||||
|
|
||||||
|
class PositionEmbeddingSineHW(nn.Module):
|
||||||
|
"""
|
||||||
|
This is a more standard version of the position embedding, very similar to the one
|
||||||
|
used by the Attention is all you need paper, generalized to work on images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_pos_feats = num_pos_feats
|
||||||
|
self.temperatureH = temperatureH
|
||||||
|
self.temperatureW = temperatureW
|
||||||
|
self.normalize = normalize
|
||||||
|
if scale is not None and normalize is False:
|
||||||
|
raise ValueError("normalize should be True if scale is passed")
|
||||||
|
if scale is None:
|
||||||
|
scale = 2 * math.pi
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def forward(self, tensor_list: NestedTensor):
|
||||||
|
x = tensor_list.tensors
|
||||||
|
mask = tensor_list.mask
|
||||||
|
assert mask is not None
|
||||||
|
not_mask = ~mask
|
||||||
|
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
||||||
|
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
||||||
|
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
|
||||||
|
if self.normalize:
|
||||||
|
eps = 1e-6
|
||||||
|
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||||
|
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||||
|
|
||||||
|
dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||||
|
dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats)
|
||||||
|
pos_x = x_embed[:, :, :, None] / dim_tx
|
||||||
|
|
||||||
|
dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||||
|
dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats)
|
||||||
|
pos_y = y_embed[:, :, :, None] / dim_ty
|
||||||
|
|
||||||
|
pos_x = torch.stack(
|
||||||
|
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
||||||
|
).flatten(3)
|
||||||
|
pos_y = torch.stack(
|
||||||
|
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
||||||
|
).flatten(3)
|
||||||
|
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||||
|
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
|
||||||
|
return pos
|
||||||
|
|
||||||
|
|
||||||
|
class PositionEmbeddingLearned(nn.Module):
|
||||||
|
"""
|
||||||
|
Absolute pos embedding, learned.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_pos_feats=256):
|
||||||
|
super().__init__()
|
||||||
|
self.row_embed = nn.Embedding(50, num_pos_feats)
|
||||||
|
self.col_embed = nn.Embedding(50, num_pos_feats)
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
nn.init.uniform_(self.row_embed.weight)
|
||||||
|
nn.init.uniform_(self.col_embed.weight)
|
||||||
|
|
||||||
|
def forward(self, tensor_list: NestedTensor):
|
||||||
|
x = tensor_list.tensors
|
||||||
|
h, w = x.shape[-2:]
|
||||||
|
i = torch.arange(w, device=x.device)
|
||||||
|
j = torch.arange(h, device=x.device)
|
||||||
|
x_emb = self.col_embed(i)
|
||||||
|
y_emb = self.row_embed(j)
|
||||||
|
pos = (
|
||||||
|
torch.cat(
|
||||||
|
[
|
||||||
|
x_emb.unsqueeze(0).repeat(h, 1, 1),
|
||||||
|
y_emb.unsqueeze(1).repeat(1, w, 1),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
.permute(2, 0, 1)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.repeat(x.shape[0], 1, 1, 1)
|
||||||
|
)
|
||||||
|
return pos
|
||||||
|
|
||||||
|
|
||||||
|
def build_position_encoding(args):
|
||||||
|
N_steps = args.hidden_dim // 2
|
||||||
|
if args.position_embedding in ("v2", "sine"):
|
||||||
|
# TODO find a better way of exposing other arguments
|
||||||
|
position_embedding = PositionEmbeddingSineHW(
|
||||||
|
N_steps,
|
||||||
|
temperatureH=args.pe_temperatureH,
|
||||||
|
temperatureW=args.pe_temperatureW,
|
||||||
|
normalize=True,
|
||||||
|
)
|
||||||
|
elif args.position_embedding in ("v3", "learned"):
|
||||||
|
position_embedding = PositionEmbeddingLearned(N_steps)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"not supported {args.position_embedding}")
|
||||||
|
|
||||||
|
return position_embedding
|
@ -0,0 +1,802 @@
|
|||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Grounding DINO
|
||||||
|
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||||
|
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# DINO
|
||||||
|
# Copyright (c) 2022 IDEA. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.utils.checkpoint as checkpoint
|
||||||
|
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||||||
|
|
||||||
|
from groundingdino.util.misc import NestedTensor
|
||||||
|
|
||||||
|
|
||||||
|
class Mlp(nn.Module):
|
||||||
|
"""Multilayer perceptron."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
out_features = out_features or in_features
|
||||||
|
hidden_features = hidden_features or in_features
|
||||||
|
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||||
|
self.act = act_layer()
|
||||||
|
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||||
|
self.drop = nn.Dropout(drop)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.drop(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.drop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def window_partition(x, window_size):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (B, H, W, C)
|
||||||
|
window_size (int): window size
|
||||||
|
Returns:
|
||||||
|
windows: (num_windows*B, window_size, window_size, C)
|
||||||
|
"""
|
||||||
|
B, H, W, C = x.shape
|
||||||
|
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
||||||
|
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||||
|
return windows
|
||||||
|
|
||||||
|
|
||||||
|
def window_reverse(windows, window_size, H, W):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
windows: (num_windows*B, window_size, window_size, C)
|
||||||
|
window_size (int): Window size
|
||||||
|
H (int): Height of image
|
||||||
|
W (int): Width of image
|
||||||
|
Returns:
|
||||||
|
x: (B, H, W, C)
|
||||||
|
"""
|
||||||
|
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||||
|
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
||||||
|
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class WindowAttention(nn.Module):
|
||||||
|
"""Window based multi-head self attention (W-MSA) module with relative position bias.
|
||||||
|
It supports both of shifted and non-shifted window.
|
||||||
|
Args:
|
||||||
|
dim (int): Number of input channels.
|
||||||
|
window_size (tuple[int]): The height and width of the window.
|
||||||
|
num_heads (int): Number of attention heads.
|
||||||
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||||
|
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
||||||
|
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
||||||
|
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
window_size,
|
||||||
|
num_heads,
|
||||||
|
qkv_bias=True,
|
||||||
|
qk_scale=None,
|
||||||
|
attn_drop=0.0,
|
||||||
|
proj_drop=0.0,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.window_size = window_size # Wh, Ww
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
self.scale = qk_scale or head_dim**-0.5
|
||||||
|
|
||||||
|
# define a parameter table of relative position bias
|
||||||
|
self.relative_position_bias_table = nn.Parameter(
|
||||||
|
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
|
||||||
|
) # 2*Wh-1 * 2*Ww-1, nH
|
||||||
|
|
||||||
|
# get pair-wise relative position index for each token inside the window
|
||||||
|
coords_h = torch.arange(self.window_size[0])
|
||||||
|
coords_w = torch.arange(self.window_size[1])
|
||||||
|
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||||
|
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||||
|
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||||
|
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||||
|
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
||||||
|
relative_coords[:, :, 1] += self.window_size[1] - 1
|
||||||
|
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
||||||
|
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||||
|
self.register_buffer("relative_position_index", relative_position_index)
|
||||||
|
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||||
|
self.attn_drop = nn.Dropout(attn_drop)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
self.proj_drop = nn.Dropout(proj_drop)
|
||||||
|
|
||||||
|
trunc_normal_(self.relative_position_bias_table, std=0.02)
|
||||||
|
self.softmax = nn.Softmax(dim=-1)
|
||||||
|
|
||||||
|
def forward(self, x, mask=None):
|
||||||
|
"""Forward function.
|
||||||
|
Args:
|
||||||
|
x: input features with shape of (num_windows*B, N, C)
|
||||||
|
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
||||||
|
"""
|
||||||
|
B_, N, C = x.shape
|
||||||
|
qkv = (
|
||||||
|
self.qkv(x)
|
||||||
|
.reshape(B_, N, 3, self.num_heads, C // self.num_heads)
|
||||||
|
.permute(2, 0, 3, 1, 4)
|
||||||
|
)
|
||||||
|
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
||||||
|
|
||||||
|
q = q * self.scale
|
||||||
|
attn = q @ k.transpose(-2, -1)
|
||||||
|
|
||||||
|
relative_position_bias = self.relative_position_bias_table[
|
||||||
|
self.relative_position_index.view(-1)
|
||||||
|
].view(
|
||||||
|
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
|
||||||
|
) # Wh*Ww,Wh*Ww,nH
|
||||||
|
relative_position_bias = relative_position_bias.permute(
|
||||||
|
2, 0, 1
|
||||||
|
).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||||
|
attn = attn + relative_position_bias.unsqueeze(0)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
nW = mask.shape[0]
|
||||||
|
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
||||||
|
attn = attn.view(-1, self.num_heads, N, N)
|
||||||
|
attn = self.softmax(attn)
|
||||||
|
else:
|
||||||
|
attn = self.softmax(attn)
|
||||||
|
|
||||||
|
attn = self.attn_drop(attn)
|
||||||
|
|
||||||
|
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
||||||
|
x = self.proj(x)
|
||||||
|
x = self.proj_drop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SwinTransformerBlock(nn.Module):
|
||||||
|
"""Swin Transformer Block.
|
||||||
|
Args:
|
||||||
|
dim (int): Number of input channels.
|
||||||
|
num_heads (int): Number of attention heads.
|
||||||
|
window_size (int): Window size.
|
||||||
|
shift_size (int): Shift size for SW-MSA.
|
||||||
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||||
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||||
|
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
||||||
|
drop (float, optional): Dropout rate. Default: 0.0
|
||||||
|
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||||
|
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
||||||
|
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
||||||
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
num_heads,
|
||||||
|
window_size=7,
|
||||||
|
shift_size=0,
|
||||||
|
mlp_ratio=4.0,
|
||||||
|
qkv_bias=True,
|
||||||
|
qk_scale=None,
|
||||||
|
drop=0.0,
|
||||||
|
attn_drop=0.0,
|
||||||
|
drop_path=0.0,
|
||||||
|
act_layer=nn.GELU,
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.window_size = window_size
|
||||||
|
self.shift_size = shift_size
|
||||||
|
self.mlp_ratio = mlp_ratio
|
||||||
|
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
||||||
|
|
||||||
|
self.norm1 = norm_layer(dim)
|
||||||
|
self.attn = WindowAttention(
|
||||||
|
dim,
|
||||||
|
window_size=to_2tuple(self.window_size),
|
||||||
|
num_heads=num_heads,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_scale=qk_scale,
|
||||||
|
attn_drop=attn_drop,
|
||||||
|
proj_drop=drop,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||||
|
self.norm2 = norm_layer(dim)
|
||||||
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
|
self.mlp = Mlp(
|
||||||
|
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
|
||||||
|
)
|
||||||
|
|
||||||
|
self.H = None
|
||||||
|
self.W = None
|
||||||
|
|
||||||
|
def forward(self, x, mask_matrix):
|
||||||
|
"""Forward function.
|
||||||
|
Args:
|
||||||
|
x: Input feature, tensor size (B, H*W, C).
|
||||||
|
H, W: Spatial resolution of the input feature.
|
||||||
|
mask_matrix: Attention mask for cyclic shift.
|
||||||
|
"""
|
||||||
|
B, L, C = x.shape
|
||||||
|
H, W = self.H, self.W
|
||||||
|
assert L == H * W, "input feature has wrong size"
|
||||||
|
|
||||||
|
shortcut = x
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = x.view(B, H, W, C)
|
||||||
|
|
||||||
|
# pad feature maps to multiples of window size
|
||||||
|
pad_l = pad_t = 0
|
||||||
|
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
||||||
|
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
||||||
|
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
||||||
|
_, Hp, Wp, _ = x.shape
|
||||||
|
|
||||||
|
# cyclic shift
|
||||||
|
if self.shift_size > 0:
|
||||||
|
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
||||||
|
attn_mask = mask_matrix
|
||||||
|
else:
|
||||||
|
shifted_x = x
|
||||||
|
attn_mask = None
|
||||||
|
|
||||||
|
# partition windows
|
||||||
|
x_windows = window_partition(
|
||||||
|
shifted_x, self.window_size
|
||||||
|
) # nW*B, window_size, window_size, C
|
||||||
|
x_windows = x_windows.view(
|
||||||
|
-1, self.window_size * self.window_size, C
|
||||||
|
) # nW*B, window_size*window_size, C
|
||||||
|
|
||||||
|
# W-MSA/SW-MSA
|
||||||
|
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
||||||
|
|
||||||
|
# merge windows
|
||||||
|
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
||||||
|
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
|
||||||
|
|
||||||
|
# reverse cyclic shift
|
||||||
|
if self.shift_size > 0:
|
||||||
|
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
||||||
|
else:
|
||||||
|
x = shifted_x
|
||||||
|
|
||||||
|
if pad_r > 0 or pad_b > 0:
|
||||||
|
x = x[:, :H, :W, :].contiguous()
|
||||||
|
|
||||||
|
x = x.view(B, H * W, C)
|
||||||
|
|
||||||
|
# FFN
|
||||||
|
x = shortcut + self.drop_path(x)
|
||||||
|
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PatchMerging(nn.Module):
|
||||||
|
"""Patch Merging Layer
|
||||||
|
Args:
|
||||||
|
dim (int): Number of input channels.
|
||||||
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
||||||
|
self.norm = norm_layer(4 * dim)
|
||||||
|
|
||||||
|
def forward(self, x, H, W):
|
||||||
|
"""Forward function.
|
||||||
|
Args:
|
||||||
|
x: Input feature, tensor size (B, H*W, C).
|
||||||
|
H, W: Spatial resolution of the input feature.
|
||||||
|
"""
|
||||||
|
B, L, C = x.shape
|
||||||
|
assert L == H * W, "input feature has wrong size"
|
||||||
|
|
||||||
|
x = x.view(B, H, W, C)
|
||||||
|
|
||||||
|
# padding
|
||||||
|
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
||||||
|
if pad_input:
|
||||||
|
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
||||||
|
|
||||||
|
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
||||||
|
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
||||||
|
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
||||||
|
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
||||||
|
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
||||||
|
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
||||||
|
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.reduction(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class BasicLayer(nn.Module):
|
||||||
|
"""A basic Swin Transformer layer for one stage.
|
||||||
|
Args:
|
||||||
|
dim (int): Number of feature channels
|
||||||
|
depth (int): Depths of this stage.
|
||||||
|
num_heads (int): Number of attention head.
|
||||||
|
window_size (int): Local window size. Default: 7.
|
||||||
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
||||||
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||||
|
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
||||||
|
drop (float, optional): Dropout rate. Default: 0.0
|
||||||
|
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||||
|
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
||||||
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||||
|
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
||||||
|
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
depth,
|
||||||
|
num_heads,
|
||||||
|
window_size=7,
|
||||||
|
mlp_ratio=4.0,
|
||||||
|
qkv_bias=True,
|
||||||
|
qk_scale=None,
|
||||||
|
drop=0.0,
|
||||||
|
attn_drop=0.0,
|
||||||
|
drop_path=0.0,
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
downsample=None,
|
||||||
|
use_checkpoint=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.window_size = window_size
|
||||||
|
self.shift_size = window_size // 2
|
||||||
|
self.depth = depth
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
|
||||||
|
# build blocks
|
||||||
|
self.blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
SwinTransformerBlock(
|
||||||
|
dim=dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
window_size=window_size,
|
||||||
|
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_scale=qk_scale,
|
||||||
|
drop=drop,
|
||||||
|
attn_drop=attn_drop,
|
||||||
|
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
)
|
||||||
|
for i in range(depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# patch merging layer
|
||||||
|
if downsample is not None:
|
||||||
|
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
||||||
|
else:
|
||||||
|
self.downsample = None
|
||||||
|
|
||||||
|
def forward(self, x, H, W):
|
||||||
|
"""Forward function.
|
||||||
|
Args:
|
||||||
|
x: Input feature, tensor size (B, H*W, C).
|
||||||
|
H, W: Spatial resolution of the input feature.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# calculate attention mask for SW-MSA
|
||||||
|
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
||||||
|
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
||||||
|
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
||||||
|
h_slices = (
|
||||||
|
slice(0, -self.window_size),
|
||||||
|
slice(-self.window_size, -self.shift_size),
|
||||||
|
slice(-self.shift_size, None),
|
||||||
|
)
|
||||||
|
w_slices = (
|
||||||
|
slice(0, -self.window_size),
|
||||||
|
slice(-self.window_size, -self.shift_size),
|
||||||
|
slice(-self.shift_size, None),
|
||||||
|
)
|
||||||
|
cnt = 0
|
||||||
|
for h in h_slices:
|
||||||
|
for w in w_slices:
|
||||||
|
img_mask[:, h, w, :] = cnt
|
||||||
|
cnt += 1
|
||||||
|
|
||||||
|
mask_windows = window_partition(
|
||||||
|
img_mask, self.window_size
|
||||||
|
) # nW, window_size, window_size, 1
|
||||||
|
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
||||||
|
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||||
|
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
|
||||||
|
attn_mask == 0, float(0.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
for blk in self.blocks:
|
||||||
|
blk.H, blk.W = H, W
|
||||||
|
if self.use_checkpoint:
|
||||||
|
x = checkpoint.checkpoint(blk, x, attn_mask)
|
||||||
|
else:
|
||||||
|
x = blk(x, attn_mask)
|
||||||
|
if self.downsample is not None:
|
||||||
|
x_down = self.downsample(x, H, W)
|
||||||
|
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
||||||
|
return x, H, W, x_down, Wh, Ww
|
||||||
|
else:
|
||||||
|
return x, H, W, x, H, W
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEmbed(nn.Module):
|
||||||
|
"""Image to Patch Embedding
|
||||||
|
Args:
|
||||||
|
patch_size (int): Patch token size. Default: 4.
|
||||||
|
in_chans (int): Number of input image channels. Default: 3.
|
||||||
|
embed_dim (int): Number of linear projection output channels. Default: 96.
|
||||||
|
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
||||||
|
super().__init__()
|
||||||
|
patch_size = to_2tuple(patch_size)
|
||||||
|
self.patch_size = patch_size
|
||||||
|
|
||||||
|
self.in_chans = in_chans
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
|
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||||
|
if norm_layer is not None:
|
||||||
|
self.norm = norm_layer(embed_dim)
|
||||||
|
else:
|
||||||
|
self.norm = None
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Forward function."""
|
||||||
|
# padding
|
||||||
|
_, _, H, W = x.size()
|
||||||
|
if W % self.patch_size[1] != 0:
|
||||||
|
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
||||||
|
if H % self.patch_size[0] != 0:
|
||||||
|
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
||||||
|
|
||||||
|
x = self.proj(x) # B C Wh Ww
|
||||||
|
if self.norm is not None:
|
||||||
|
Wh, Ww = x.size(2), x.size(3)
|
||||||
|
x = x.flatten(2).transpose(1, 2)
|
||||||
|
x = self.norm(x)
|
||||||
|
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SwinTransformer(nn.Module):
|
||||||
|
"""Swin Transformer backbone.
|
||||||
|
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
||||||
|
https://arxiv.org/pdf/2103.14030
|
||||||
|
Args:
|
||||||
|
pretrain_img_size (int): Input image size for training the pretrained model,
|
||||||
|
used in absolute postion embedding. Default 224.
|
||||||
|
patch_size (int | tuple(int)): Patch size. Default: 4.
|
||||||
|
in_chans (int): Number of input image channels. Default: 3.
|
||||||
|
embed_dim (int): Number of linear projection output channels. Default: 96.
|
||||||
|
depths (tuple[int]): Depths of each Swin Transformer stage.
|
||||||
|
num_heads (tuple[int]): Number of attention head of each stage.
|
||||||
|
window_size (int): Window size. Default: 7.
|
||||||
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
||||||
|
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
||||||
|
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
||||||
|
drop_rate (float): Dropout rate.
|
||||||
|
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
||||||
|
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
||||||
|
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
||||||
|
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
||||||
|
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
||||||
|
out_indices (Sequence[int]): Output from which stages.
|
||||||
|
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||||
|
-1 means not freezing any parameters.
|
||||||
|
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
||||||
|
dilation (bool): if True, the output size if 16x downsample, ow 32x downsample.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pretrain_img_size=224,
|
||||||
|
patch_size=4,
|
||||||
|
in_chans=3,
|
||||||
|
embed_dim=96,
|
||||||
|
depths=[2, 2, 6, 2],
|
||||||
|
num_heads=[3, 6, 12, 24],
|
||||||
|
window_size=7,
|
||||||
|
mlp_ratio=4.0,
|
||||||
|
qkv_bias=True,
|
||||||
|
qk_scale=None,
|
||||||
|
drop_rate=0.0,
|
||||||
|
attn_drop_rate=0.0,
|
||||||
|
drop_path_rate=0.2,
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
ape=False,
|
||||||
|
patch_norm=True,
|
||||||
|
out_indices=(0, 1, 2, 3),
|
||||||
|
frozen_stages=-1,
|
||||||
|
dilation=False,
|
||||||
|
use_checkpoint=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.pretrain_img_size = pretrain_img_size
|
||||||
|
self.num_layers = len(depths)
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.ape = ape
|
||||||
|
self.patch_norm = patch_norm
|
||||||
|
self.out_indices = out_indices
|
||||||
|
self.frozen_stages = frozen_stages
|
||||||
|
self.dilation = dilation
|
||||||
|
|
||||||
|
# if use_checkpoint:
|
||||||
|
# print("use_checkpoint!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||||
|
|
||||||
|
# split image into non-overlapping patches
|
||||||
|
self.patch_embed = PatchEmbed(
|
||||||
|
patch_size=patch_size,
|
||||||
|
in_chans=in_chans,
|
||||||
|
embed_dim=embed_dim,
|
||||||
|
norm_layer=norm_layer if self.patch_norm else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# absolute position embedding
|
||||||
|
if self.ape:
|
||||||
|
pretrain_img_size = to_2tuple(pretrain_img_size)
|
||||||
|
patch_size = to_2tuple(patch_size)
|
||||||
|
patches_resolution = [
|
||||||
|
pretrain_img_size[0] // patch_size[0],
|
||||||
|
pretrain_img_size[1] // patch_size[1],
|
||||||
|
]
|
||||||
|
|
||||||
|
self.absolute_pos_embed = nn.Parameter(
|
||||||
|
torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
|
||||||
|
)
|
||||||
|
trunc_normal_(self.absolute_pos_embed, std=0.02)
|
||||||
|
|
||||||
|
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||||
|
|
||||||
|
# stochastic depth
|
||||||
|
dpr = [
|
||||||
|
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
||||||
|
] # stochastic depth decay rule
|
||||||
|
|
||||||
|
# build layers
|
||||||
|
self.layers = nn.ModuleList()
|
||||||
|
# prepare downsample list
|
||||||
|
downsamplelist = [PatchMerging for i in range(self.num_layers)]
|
||||||
|
downsamplelist[-1] = None
|
||||||
|
num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
|
||||||
|
if self.dilation:
|
||||||
|
downsamplelist[-2] = None
|
||||||
|
num_features[-1] = int(embed_dim * 2 ** (self.num_layers - 1)) // 2
|
||||||
|
for i_layer in range(self.num_layers):
|
||||||
|
layer = BasicLayer(
|
||||||
|
# dim=int(embed_dim * 2 ** i_layer),
|
||||||
|
dim=num_features[i_layer],
|
||||||
|
depth=depths[i_layer],
|
||||||
|
num_heads=num_heads[i_layer],
|
||||||
|
window_size=window_size,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_scale=qk_scale,
|
||||||
|
drop=drop_rate,
|
||||||
|
attn_drop=attn_drop_rate,
|
||||||
|
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
# downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
||||||
|
downsample=downsamplelist[i_layer],
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
)
|
||||||
|
self.layers.append(layer)
|
||||||
|
|
||||||
|
# num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
||||||
|
self.num_features = num_features
|
||||||
|
|
||||||
|
# add a norm layer for each output
|
||||||
|
for i_layer in out_indices:
|
||||||
|
layer = norm_layer(num_features[i_layer])
|
||||||
|
layer_name = f"norm{i_layer}"
|
||||||
|
self.add_module(layer_name, layer)
|
||||||
|
|
||||||
|
self._freeze_stages()
|
||||||
|
|
||||||
|
def _freeze_stages(self):
|
||||||
|
if self.frozen_stages >= 0:
|
||||||
|
self.patch_embed.eval()
|
||||||
|
for param in self.patch_embed.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
if self.frozen_stages >= 1 and self.ape:
|
||||||
|
self.absolute_pos_embed.requires_grad = False
|
||||||
|
|
||||||
|
if self.frozen_stages >= 2:
|
||||||
|
self.pos_drop.eval()
|
||||||
|
for i in range(0, self.frozen_stages - 1):
|
||||||
|
m = self.layers[i]
|
||||||
|
m.eval()
|
||||||
|
for param in m.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# def init_weights(self, pretrained=None):
|
||||||
|
# """Initialize the weights in backbone.
|
||||||
|
# Args:
|
||||||
|
# pretrained (str, optional): Path to pre-trained weights.
|
||||||
|
# Defaults to None.
|
||||||
|
# """
|
||||||
|
|
||||||
|
# def _init_weights(m):
|
||||||
|
# if isinstance(m, nn.Linear):
|
||||||
|
# trunc_normal_(m.weight, std=.02)
|
||||||
|
# if isinstance(m, nn.Linear) and m.bias is not None:
|
||||||
|
# nn.init.constant_(m.bias, 0)
|
||||||
|
# elif isinstance(m, nn.LayerNorm):
|
||||||
|
# nn.init.constant_(m.bias, 0)
|
||||||
|
# nn.init.constant_(m.weight, 1.0)
|
||||||
|
|
||||||
|
# if isinstance(pretrained, str):
|
||||||
|
# self.apply(_init_weights)
|
||||||
|
# logger = get_root_logger()
|
||||||
|
# load_checkpoint(self, pretrained, strict=False, logger=logger)
|
||||||
|
# elif pretrained is None:
|
||||||
|
# self.apply(_init_weights)
|
||||||
|
# else:
|
||||||
|
# raise TypeError('pretrained must be a str or None')
|
||||||
|
|
||||||
|
def forward_raw(self, x):
|
||||||
|
"""Forward function."""
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
|
||||||
|
Wh, Ww = x.size(2), x.size(3)
|
||||||
|
if self.ape:
|
||||||
|
# interpolate the position embedding to the corresponding size
|
||||||
|
absolute_pos_embed = F.interpolate(
|
||||||
|
self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
|
||||||
|
)
|
||||||
|
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
|
||||||
|
else:
|
||||||
|
x = x.flatten(2).transpose(1, 2)
|
||||||
|
x = self.pos_drop(x)
|
||||||
|
|
||||||
|
outs = []
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
layer = self.layers[i]
|
||||||
|
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
|
||||||
|
if i in self.out_indices:
|
||||||
|
norm_layer = getattr(self, f"norm{i}")
|
||||||
|
x_out = norm_layer(x_out)
|
||||||
|
|
||||||
|
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
||||||
|
outs.append(out)
|
||||||
|
# in:
|
||||||
|
# torch.Size([2, 3, 1024, 1024])
|
||||||
|
# outs:
|
||||||
|
# [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
|
||||||
|
# torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
|
||||||
|
return tuple(outs)
|
||||||
|
|
||||||
|
def forward(self, tensor_list: NestedTensor):
|
||||||
|
x = tensor_list.tensors
|
||||||
|
|
||||||
|
"""Forward function."""
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
|
||||||
|
Wh, Ww = x.size(2), x.size(3)
|
||||||
|
if self.ape:
|
||||||
|
# interpolate the position embedding to the corresponding size
|
||||||
|
absolute_pos_embed = F.interpolate(
|
||||||
|
self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
|
||||||
|
)
|
||||||
|
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
|
||||||
|
else:
|
||||||
|
x = x.flatten(2).transpose(1, 2)
|
||||||
|
x = self.pos_drop(x)
|
||||||
|
|
||||||
|
outs = []
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
layer = self.layers[i]
|
||||||
|
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
||||||
|
|
||||||
|
if i in self.out_indices:
|
||||||
|
norm_layer = getattr(self, f"norm{i}")
|
||||||
|
x_out = norm_layer(x_out)
|
||||||
|
|
||||||
|
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
||||||
|
outs.append(out)
|
||||||
|
# in:
|
||||||
|
# torch.Size([2, 3, 1024, 1024])
|
||||||
|
# out:
|
||||||
|
# [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
|
||||||
|
# torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
|
||||||
|
|
||||||
|
# collect for nesttensors
|
||||||
|
outs_dict = {}
|
||||||
|
for idx, out_i in enumerate(outs):
|
||||||
|
m = tensor_list.mask
|
||||||
|
assert m is not None
|
||||||
|
mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[0]
|
||||||
|
outs_dict[idx] = NestedTensor(out_i, mask)
|
||||||
|
|
||||||
|
return outs_dict
|
||||||
|
|
||||||
|
def train(self, mode=True):
|
||||||
|
"""Convert the model into training mode while keep layers freezed."""
|
||||||
|
super(SwinTransformer, self).train(mode)
|
||||||
|
self._freeze_stages()
|
||||||
|
|
||||||
|
|
||||||
|
def build_swin_transformer(modelname, pretrain_img_size, **kw):
|
||||||
|
assert modelname in [
|
||||||
|
"swin_T_224_1k",
|
||||||
|
"swin_B_224_22k",
|
||||||
|
"swin_B_384_22k",
|
||||||
|
"swin_L_224_22k",
|
||||||
|
"swin_L_384_22k",
|
||||||
|
]
|
||||||
|
|
||||||
|
model_para_dict = {
|
||||||
|
"swin_T_224_1k": dict(
|
||||||
|
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7
|
||||||
|
),
|
||||||
|
"swin_B_224_22k": dict(
|
||||||
|
embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7
|
||||||
|
),
|
||||||
|
"swin_B_384_22k": dict(
|
||||||
|
embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12
|
||||||
|
),
|
||||||
|
"swin_L_224_22k": dict(
|
||||||
|
embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7
|
||||||
|
),
|
||||||
|
"swin_L_384_22k": dict(
|
||||||
|
embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12
|
||||||
|
),
|
||||||
|
}
|
||||||
|
kw_cgf = model_para_dict[modelname]
|
||||||
|
kw_cgf.update(kw)
|
||||||
|
model = SwinTransformer(pretrain_img_size=pretrain_img_size, **kw_cgf)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
model = build_swin_transformer("swin_L_384_22k", 384, dilation=True)
|
||||||
|
x = torch.rand(2, 3, 1024, 1024)
|
||||||
|
y = model.forward_raw(x)
|
||||||
|
import ipdb
|
||||||
|
|
||||||
|
ipdb.set_trace()
|
||||||
|
x = torch.rand(2, 3, 384, 384)
|
||||||
|
y = model.forward_raw(x)
|
@ -0,0 +1,273 @@
|
|||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Grounding DINO
|
||||||
|
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||||
|
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.utils.checkpoint as checkpoint
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from torchvision.ops.boxes import nms
|
||||||
|
from transformers import BertConfig, BertModel, BertPreTrainedModel
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
|
||||||
|
|
||||||
|
|
||||||
|
class BertModelWarper(nn.Module):
|
||||||
|
def __init__(self, bert_model):
|
||||||
|
super().__init__()
|
||||||
|
# self.bert = bert_modelc
|
||||||
|
|
||||||
|
self.config = bert_model.config
|
||||||
|
self.embeddings = bert_model.embeddings
|
||||||
|
self.encoder = bert_model.encoder
|
||||||
|
self.pooler = bert_model.pooler
|
||||||
|
|
||||||
|
self.get_extended_attention_mask = bert_model.get_extended_attention_mask
|
||||||
|
self.invert_attention_mask = bert_model.invert_attention_mask
|
||||||
|
self.get_head_mask = bert_model.get_head_mask
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
|
the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
|
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
|
"""
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if self.config.is_decoder:
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
else:
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
|
||||||
|
# past_key_values_length
|
||||||
|
past_key_values_length = (
|
||||||
|
past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.ones(
|
||||||
|
((batch_size, seq_length + past_key_values_length)), device=device
|
||||||
|
)
|
||||||
|
if token_type_ids is None:
|
||||||
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||||
|
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||||
|
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
|
||||||
|
attention_mask, input_shape, device
|
||||||
|
)
|
||||||
|
|
||||||
|
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||||
|
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||||
|
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||||
|
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||||
|
if encoder_attention_mask is None:
|
||||||
|
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||||
|
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||||
|
else:
|
||||||
|
encoder_extended_attention_mask = None
|
||||||
|
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
|
||||||
|
# Prepare head mask if needed
|
||||||
|
# 1.0 in head_mask indicate we keep the head
|
||||||
|
# attention_probs has shape bsz x n_heads x N x N
|
||||||
|
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||||
|
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||||
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||||
|
|
||||||
|
embedding_output = self.embeddings(
|
||||||
|
input_ids=input_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_outputs = self.encoder(
|
||||||
|
embedding_output,
|
||||||
|
attention_mask=extended_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
sequence_output = encoder_outputs[0]
|
||||||
|
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||||
|
|
||||||
|
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||||
|
last_hidden_state=sequence_output,
|
||||||
|
pooler_output=pooled_output,
|
||||||
|
past_key_values=encoder_outputs.past_key_values,
|
||||||
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
|
attentions=encoder_outputs.attentions,
|
||||||
|
cross_attentions=encoder_outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TextEncoderShell(nn.Module):
|
||||||
|
def __init__(self, text_encoder):
|
||||||
|
super().__init__()
|
||||||
|
self.text_encoder = text_encoder
|
||||||
|
self.config = self.text_encoder.config
|
||||||
|
|
||||||
|
def forward(self, **kw):
|
||||||
|
# feed into text encoder
|
||||||
|
return self.text_encoder(**kw)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_masks_with_special_tokens(tokenized, special_tokens_list, tokenizer):
|
||||||
|
"""Generate attention mask between each pair of special tokens
|
||||||
|
Args:
|
||||||
|
input_ids (torch.Tensor): input ids. Shape: [bs, num_token]
|
||||||
|
special_tokens_mask (list): special tokens mask.
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: attention mask between each special tokens.
|
||||||
|
"""
|
||||||
|
input_ids = tokenized["input_ids"]
|
||||||
|
bs, num_token = input_ids.shape
|
||||||
|
# special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
|
||||||
|
special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
|
||||||
|
for special_token in special_tokens_list:
|
||||||
|
special_tokens_mask |= input_ids == special_token
|
||||||
|
|
||||||
|
# idxs: each row is a list of indices of special tokens
|
||||||
|
idxs = torch.nonzero(special_tokens_mask)
|
||||||
|
|
||||||
|
# generate attention mask and positional ids
|
||||||
|
attention_mask = (
|
||||||
|
torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1)
|
||||||
|
)
|
||||||
|
position_ids = torch.zeros((bs, num_token), device=input_ids.device)
|
||||||
|
previous_col = 0
|
||||||
|
for i in range(idxs.shape[0]):
|
||||||
|
row, col = idxs[i]
|
||||||
|
if (col == 0) or (col == num_token - 1):
|
||||||
|
attention_mask[row, col, col] = True
|
||||||
|
position_ids[row, col] = 0
|
||||||
|
else:
|
||||||
|
attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
|
||||||
|
position_ids[row, previous_col + 1 : col + 1] = torch.arange(
|
||||||
|
0, col - previous_col, device=input_ids.device
|
||||||
|
)
|
||||||
|
|
||||||
|
previous_col = col
|
||||||
|
|
||||||
|
# # padding mask
|
||||||
|
# padding_mask = tokenized['attention_mask']
|
||||||
|
# attention_mask = attention_mask & padding_mask.unsqueeze(1).bool() & padding_mask.unsqueeze(2).bool()
|
||||||
|
|
||||||
|
return attention_mask, position_ids.to(torch.long)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_masks_with_special_tokens_and_transfer_map(tokenized, special_tokens_list, tokenizer):
|
||||||
|
"""Generate attention mask between each pair of special tokens
|
||||||
|
Args:
|
||||||
|
input_ids (torch.Tensor): input ids. Shape: [bs, num_token]
|
||||||
|
special_tokens_mask (list): special tokens mask.
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: attention mask between each special tokens.
|
||||||
|
"""
|
||||||
|
input_ids = tokenized["input_ids"]
|
||||||
|
bs, num_token = input_ids.shape
|
||||||
|
# special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
|
||||||
|
special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
|
||||||
|
for special_token in special_tokens_list:
|
||||||
|
special_tokens_mask |= input_ids == special_token
|
||||||
|
|
||||||
|
# idxs: each row is a list of indices of special tokens
|
||||||
|
idxs = torch.nonzero(special_tokens_mask)
|
||||||
|
|
||||||
|
# generate attention mask and positional ids
|
||||||
|
attention_mask = (
|
||||||
|
torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1)
|
||||||
|
)
|
||||||
|
position_ids = torch.zeros((bs, num_token), device=input_ids.device)
|
||||||
|
cate_to_token_mask_list = [[] for _ in range(bs)]
|
||||||
|
previous_col = 0
|
||||||
|
for i in range(idxs.shape[0]):
|
||||||
|
row, col = idxs[i]
|
||||||
|
if (col == 0) or (col == num_token - 1):
|
||||||
|
attention_mask[row, col, col] = True
|
||||||
|
position_ids[row, col] = 0
|
||||||
|
else:
|
||||||
|
attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
|
||||||
|
position_ids[row, previous_col + 1 : col + 1] = torch.arange(
|
||||||
|
0, col - previous_col, device=input_ids.device
|
||||||
|
)
|
||||||
|
c2t_maski = torch.zeros((num_token), device=input_ids.device).bool()
|
||||||
|
c2t_maski[previous_col + 1 : col] = True
|
||||||
|
cate_to_token_mask_list[row].append(c2t_maski)
|
||||||
|
previous_col = col
|
||||||
|
|
||||||
|
cate_to_token_mask_list = [
|
||||||
|
torch.stack(cate_to_token_mask_listi, dim=0)
|
||||||
|
for cate_to_token_mask_listi in cate_to_token_mask_list
|
||||||
|
]
|
||||||
|
|
||||||
|
# # padding mask
|
||||||
|
# padding_mask = tokenized['attention_mask']
|
||||||
|
# attention_mask = attention_mask & padding_mask.unsqueeze(1).bool() & padding_mask.unsqueeze(2).bool()
|
||||||
|
|
||||||
|
return attention_mask, position_ids.to(torch.long), cate_to_token_mask_list
|
@ -0,0 +1,64 @@
|
|||||||
|
/*!
|
||||||
|
**************************************************************************************************
|
||||||
|
* Deformable DETR
|
||||||
|
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
||||||
|
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
**************************************************************************************************
|
||||||
|
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
||||||
|
**************************************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ms_deform_attn_cpu.h"
|
||||||
|
|
||||||
|
#ifdef WITH_CUDA
|
||||||
|
#include "ms_deform_attn_cuda.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace groundingdino {
|
||||||
|
|
||||||
|
at::Tensor
|
||||||
|
ms_deform_attn_forward(
|
||||||
|
const at::Tensor &value,
|
||||||
|
const at::Tensor &spatial_shapes,
|
||||||
|
const at::Tensor &level_start_index,
|
||||||
|
const at::Tensor &sampling_loc,
|
||||||
|
const at::Tensor &attn_weight,
|
||||||
|
const int im2col_step)
|
||||||
|
{
|
||||||
|
if (value.type().is_cuda())
|
||||||
|
{
|
||||||
|
#ifdef WITH_CUDA
|
||||||
|
return ms_deform_attn_cuda_forward(
|
||||||
|
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
|
||||||
|
#else
|
||||||
|
AT_ERROR("Not compiled with GPU support");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
AT_ERROR("Not implemented on the CPU");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<at::Tensor>
|
||||||
|
ms_deform_attn_backward(
|
||||||
|
const at::Tensor &value,
|
||||||
|
const at::Tensor &spatial_shapes,
|
||||||
|
const at::Tensor &level_start_index,
|
||||||
|
const at::Tensor &sampling_loc,
|
||||||
|
const at::Tensor &attn_weight,
|
||||||
|
const at::Tensor &grad_output,
|
||||||
|
const int im2col_step)
|
||||||
|
{
|
||||||
|
if (value.type().is_cuda())
|
||||||
|
{
|
||||||
|
#ifdef WITH_CUDA
|
||||||
|
return ms_deform_attn_cuda_backward(
|
||||||
|
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
|
||||||
|
#else
|
||||||
|
AT_ERROR("Not compiled with GPU support");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
AT_ERROR("Not implemented on the CPU");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace groundingdino
|
@ -0,0 +1,43 @@
|
|||||||
|
/*!
|
||||||
|
**************************************************************************************************
|
||||||
|
* Deformable DETR
|
||||||
|
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
||||||
|
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
**************************************************************************************************
|
||||||
|
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
||||||
|
**************************************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <ATen/ATen.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
namespace groundingdino {
|
||||||
|
|
||||||
|
at::Tensor
|
||||||
|
ms_deform_attn_cpu_forward(
|
||||||
|
const at::Tensor &value,
|
||||||
|
const at::Tensor &spatial_shapes,
|
||||||
|
const at::Tensor &level_start_index,
|
||||||
|
const at::Tensor &sampling_loc,
|
||||||
|
const at::Tensor &attn_weight,
|
||||||
|
const int im2col_step)
|
||||||
|
{
|
||||||
|
AT_ERROR("Not implement on cpu");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<at::Tensor>
|
||||||
|
ms_deform_attn_cpu_backward(
|
||||||
|
const at::Tensor &value,
|
||||||
|
const at::Tensor &spatial_shapes,
|
||||||
|
const at::Tensor &level_start_index,
|
||||||
|
const at::Tensor &sampling_loc,
|
||||||
|
const at::Tensor &attn_weight,
|
||||||
|
const at::Tensor &grad_output,
|
||||||
|
const int im2col_step)
|
||||||
|
{
|
||||||
|
AT_ERROR("Not implement on cpu");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace groundingdino
|
@ -0,0 +1,35 @@
|
|||||||
|
/*!
|
||||||
|
**************************************************************************************************
|
||||||
|
* Deformable DETR
|
||||||
|
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
||||||
|
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
**************************************************************************************************
|
||||||
|
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
||||||
|
**************************************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
namespace groundingdino {
|
||||||
|
|
||||||
|
at::Tensor
|
||||||
|
ms_deform_attn_cpu_forward(
|
||||||
|
const at::Tensor &value,
|
||||||
|
const at::Tensor &spatial_shapes,
|
||||||
|
const at::Tensor &level_start_index,
|
||||||
|
const at::Tensor &sampling_loc,
|
||||||
|
const at::Tensor &attn_weight,
|
||||||
|
const int im2col_step);
|
||||||
|
|
||||||
|
std::vector<at::Tensor>
|
||||||
|
ms_deform_attn_cpu_backward(
|
||||||
|
const at::Tensor &value,
|
||||||
|
const at::Tensor &spatial_shapes,
|
||||||
|
const at::Tensor &level_start_index,
|
||||||
|
const at::Tensor &sampling_loc,
|
||||||
|
const at::Tensor &attn_weight,
|
||||||
|
const at::Tensor &grad_output,
|
||||||
|
const int im2col_step);
|
||||||
|
|
||||||
|
} // namespace groundingdino
|
@ -0,0 +1,156 @@
|
|||||||
|
/*!
|
||||||
|
**************************************************************************************************
|
||||||
|
* Deformable DETR
|
||||||
|
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
||||||
|
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
**************************************************************************************************
|
||||||
|
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
||||||
|
**************************************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include "ms_deform_im2col_cuda.cuh"
|
||||||
|
|
||||||
|
#include <ATen/ATen.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
namespace groundingdino {
|
||||||
|
|
||||||
|
at::Tensor ms_deform_attn_cuda_forward(
|
||||||
|
const at::Tensor &value,
|
||||||
|
const at::Tensor &spatial_shapes,
|
||||||
|
const at::Tensor &level_start_index,
|
||||||
|
const at::Tensor &sampling_loc,
|
||||||
|
const at::Tensor &attn_weight,
|
||||||
|
const int im2col_step)
|
||||||
|
{
|
||||||
|
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
|
||||||
|
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
|
||||||
|
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
|
||||||
|
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
|
||||||
|
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
||||||
|
|
||||||
|
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
|
||||||
|
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
|
||||||
|
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
|
||||||
|
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
|
||||||
|
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
|
||||||
|
|
||||||
|
const int batch = value.size(0);
|
||||||
|
const int spatial_size = value.size(1);
|
||||||
|
const int num_heads = value.size(2);
|
||||||
|
const int channels = value.size(3);
|
||||||
|
|
||||||
|
const int num_levels = spatial_shapes.size(0);
|
||||||
|
|
||||||
|
const int num_query = sampling_loc.size(1);
|
||||||
|
const int num_point = sampling_loc.size(4);
|
||||||
|
|
||||||
|
const int im2col_step_ = std::min(batch, im2col_step);
|
||||||
|
|
||||||
|
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
|
||||||
|
|
||||||
|
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
|
||||||
|
|
||||||
|
const int batch_n = im2col_step_;
|
||||||
|
auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
|
||||||
|
auto per_value_size = spatial_size * num_heads * channels;
|
||||||
|
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
|
||||||
|
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
|
||||||
|
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||||
|
{
|
||||||
|
auto columns = output_n.select(0, n);
|
||||||
|
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
||||||
|
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
||||||
|
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||||
|
spatial_shapes.data<int64_t>(),
|
||||||
|
level_start_index.data<int64_t>(),
|
||||||
|
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
|
||||||
|
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
|
||||||
|
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
|
||||||
|
columns.data<scalar_t>());
|
||||||
|
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
output = output.view({batch, num_query, num_heads*channels});
|
||||||
|
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
||||||
|
const at::Tensor &value,
|
||||||
|
const at::Tensor &spatial_shapes,
|
||||||
|
const at::Tensor &level_start_index,
|
||||||
|
const at::Tensor &sampling_loc,
|
||||||
|
const at::Tensor &attn_weight,
|
||||||
|
const at::Tensor &grad_output,
|
||||||
|
const int im2col_step)
|
||||||
|
{
|
||||||
|
|
||||||
|
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
|
||||||
|
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
|
||||||
|
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
|
||||||
|
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
|
||||||
|
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
||||||
|
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
|
||||||
|
|
||||||
|
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
|
||||||
|
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
|
||||||
|
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
|
||||||
|
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
|
||||||
|
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
|
||||||
|
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
|
||||||
|
|
||||||
|
const int batch = value.size(0);
|
||||||
|
const int spatial_size = value.size(1);
|
||||||
|
const int num_heads = value.size(2);
|
||||||
|
const int channels = value.size(3);
|
||||||
|
|
||||||
|
const int num_levels = spatial_shapes.size(0);
|
||||||
|
|
||||||
|
const int num_query = sampling_loc.size(1);
|
||||||
|
const int num_point = sampling_loc.size(4);
|
||||||
|
|
||||||
|
const int im2col_step_ = std::min(batch, im2col_step);
|
||||||
|
|
||||||
|
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
|
||||||
|
|
||||||
|
auto grad_value = at::zeros_like(value);
|
||||||
|
auto grad_sampling_loc = at::zeros_like(sampling_loc);
|
||||||
|
auto grad_attn_weight = at::zeros_like(attn_weight);
|
||||||
|
|
||||||
|
const int batch_n = im2col_step_;
|
||||||
|
auto per_value_size = spatial_size * num_heads * channels;
|
||||||
|
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
|
||||||
|
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
|
||||||
|
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
|
||||||
|
|
||||||
|
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||||
|
{
|
||||||
|
auto grad_output_g = grad_output_n.select(0, n);
|
||||||
|
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
||||||
|
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
||||||
|
grad_output_g.data<scalar_t>(),
|
||||||
|
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||||
|
spatial_shapes.data<int64_t>(),
|
||||||
|
level_start_index.data<int64_t>(),
|
||||||
|
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
|
||||||
|
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
|
||||||
|
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
|
||||||
|
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||||
|
grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
|
||||||
|
grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
|
||||||
|
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
grad_value, grad_sampling_loc, grad_attn_weight
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace groundingdino
|
@ -0,0 +1,33 @@
|
|||||||
|
/*!
|
||||||
|
**************************************************************************************************
|
||||||
|
* Deformable DETR
|
||||||
|
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
||||||
|
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
**************************************************************************************************
|
||||||
|
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
||||||
|
**************************************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
namespace groundingdino {
|
||||||
|
|
||||||
|
at::Tensor ms_deform_attn_cuda_forward(
|
||||||
|
const at::Tensor &value,
|
||||||
|
const at::Tensor &spatial_shapes,
|
||||||
|
const at::Tensor &level_start_index,
|
||||||
|
const at::Tensor &sampling_loc,
|
||||||
|
const at::Tensor &attn_weight,
|
||||||
|
const int im2col_step);
|
||||||
|
|
||||||
|
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
||||||
|
const at::Tensor &value,
|
||||||
|
const at::Tensor &spatial_shapes,
|
||||||
|
const at::Tensor &level_start_index,
|
||||||
|
const at::Tensor &sampling_loc,
|
||||||
|
const at::Tensor &attn_weight,
|
||||||
|
const at::Tensor &grad_output,
|
||||||
|
const int im2col_step);
|
||||||
|
|
||||||
|
} // namespace groundingdino
|
@ -0,0 +1,7 @@
|
|||||||
|
#include <cuda_runtime_api.h>
|
||||||
|
|
||||||
|
namespace groundingdino {
|
||||||
|
int get_cudart_version() {
|
||||||
|
return CUDART_VERSION;
|
||||||
|
}
|
||||||
|
} // namespace groundingdino
|
@ -0,0 +1,58 @@
|
|||||||
|
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
|
||||||
|
#include "MsDeformAttn/ms_deform_attn.h"
|
||||||
|
|
||||||
|
namespace groundingdino {
|
||||||
|
|
||||||
|
#ifdef WITH_CUDA
|
||||||
|
extern int get_cudart_version();
|
||||||
|
#endif
|
||||||
|
|
||||||
|
std::string get_cuda_version() {
|
||||||
|
#ifdef WITH_CUDA
|
||||||
|
std::ostringstream oss;
|
||||||
|
|
||||||
|
// copied from
|
||||||
|
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
|
||||||
|
auto printCudaStyleVersion = [&](int v) {
|
||||||
|
oss << (v / 1000) << "." << (v / 10 % 100);
|
||||||
|
if (v % 10 != 0) {
|
||||||
|
oss << "." << (v % 10);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
printCudaStyleVersion(get_cudart_version());
|
||||||
|
return oss.str();
|
||||||
|
#else
|
||||||
|
return std::string("not available");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// similar to
|
||||||
|
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp
|
||||||
|
std::string get_compiler_version() {
|
||||||
|
std::ostringstream ss;
|
||||||
|
#if defined(__GNUC__)
|
||||||
|
#ifndef __clang__
|
||||||
|
{ ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; }
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(__clang_major__)
|
||||||
|
{
|
||||||
|
ss << "clang " << __clang_major__ << "." << __clang_minor__ << "."
|
||||||
|
<< __clang_patchlevel__;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(_MSC_VER)
|
||||||
|
{ ss << "MSVC " << _MSC_FULL_VER; }
|
||||||
|
#endif
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
|
||||||
|
m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace groundingdino
|
@ -0,0 +1,297 @@
|
|||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Grounding DINO
|
||||||
|
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||||
|
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from timm.models.layers import DropPath
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureResizer(nn.Module):
|
||||||
|
"""
|
||||||
|
This class takes as input a set of embeddings of dimension C1 and outputs a set of
|
||||||
|
embedding of dimension C2, after a linear transformation, dropout and normalization (LN).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True):
|
||||||
|
super().__init__()
|
||||||
|
self.do_ln = do_ln
|
||||||
|
# Object feature encoding
|
||||||
|
self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True)
|
||||||
|
self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(self, encoder_features):
|
||||||
|
x = self.fc(encoder_features)
|
||||||
|
if self.do_ln:
|
||||||
|
x = self.layer_norm(x)
|
||||||
|
output = self.dropout(x)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def l1norm(X, dim, eps=1e-8):
|
||||||
|
"""L1-normalize columns of X"""
|
||||||
|
norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
|
||||||
|
X = torch.div(X, norm)
|
||||||
|
return X
|
||||||
|
|
||||||
|
|
||||||
|
def l2norm(X, dim, eps=1e-8):
|
||||||
|
"""L2-normalize columns of X"""
|
||||||
|
norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
|
||||||
|
X = torch.div(X, norm)
|
||||||
|
return X
|
||||||
|
|
||||||
|
|
||||||
|
def func_attention(query, context, smooth=1, raw_feature_norm="softmax", eps=1e-8):
|
||||||
|
"""
|
||||||
|
query: (n_context, queryL, d)
|
||||||
|
context: (n_context, sourceL, d)
|
||||||
|
"""
|
||||||
|
batch_size_q, queryL = query.size(0), query.size(1)
|
||||||
|
batch_size, sourceL = context.size(0), context.size(1)
|
||||||
|
|
||||||
|
# Get attention
|
||||||
|
# --> (batch, d, queryL)
|
||||||
|
queryT = torch.transpose(query, 1, 2)
|
||||||
|
|
||||||
|
# (batch, sourceL, d)(batch, d, queryL)
|
||||||
|
# --> (batch, sourceL, queryL)
|
||||||
|
attn = torch.bmm(context, queryT)
|
||||||
|
if raw_feature_norm == "softmax":
|
||||||
|
# --> (batch*sourceL, queryL)
|
||||||
|
attn = attn.view(batch_size * sourceL, queryL)
|
||||||
|
attn = nn.Softmax()(attn)
|
||||||
|
# --> (batch, sourceL, queryL)
|
||||||
|
attn = attn.view(batch_size, sourceL, queryL)
|
||||||
|
elif raw_feature_norm == "l2norm":
|
||||||
|
attn = l2norm(attn, 2)
|
||||||
|
elif raw_feature_norm == "clipped_l2norm":
|
||||||
|
attn = nn.LeakyReLU(0.1)(attn)
|
||||||
|
attn = l2norm(attn, 2)
|
||||||
|
else:
|
||||||
|
raise ValueError("unknown first norm type:", raw_feature_norm)
|
||||||
|
# --> (batch, queryL, sourceL)
|
||||||
|
attn = torch.transpose(attn, 1, 2).contiguous()
|
||||||
|
# --> (batch*queryL, sourceL)
|
||||||
|
attn = attn.view(batch_size * queryL, sourceL)
|
||||||
|
attn = nn.Softmax()(attn * smooth)
|
||||||
|
# --> (batch, queryL, sourceL)
|
||||||
|
attn = attn.view(batch_size, queryL, sourceL)
|
||||||
|
# --> (batch, sourceL, queryL)
|
||||||
|
attnT = torch.transpose(attn, 1, 2).contiguous()
|
||||||
|
|
||||||
|
# --> (batch, d, sourceL)
|
||||||
|
contextT = torch.transpose(context, 1, 2)
|
||||||
|
# (batch x d x sourceL)(batch x sourceL x queryL)
|
||||||
|
# --> (batch, d, queryL)
|
||||||
|
weightedContext = torch.bmm(contextT, attnT)
|
||||||
|
# --> (batch, queryL, d)
|
||||||
|
weightedContext = torch.transpose(weightedContext, 1, 2)
|
||||||
|
|
||||||
|
return weightedContext, attnT
|
||||||
|
|
||||||
|
|
||||||
|
class BiMultiHeadAttention(nn.Module):
|
||||||
|
def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None):
|
||||||
|
super(BiMultiHeadAttention, self).__init__()
|
||||||
|
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = embed_dim // num_heads
|
||||||
|
self.v_dim = v_dim
|
||||||
|
self.l_dim = l_dim
|
||||||
|
|
||||||
|
assert (
|
||||||
|
self.head_dim * self.num_heads == self.embed_dim
|
||||||
|
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
|
||||||
|
self.scale = self.head_dim ** (-0.5)
|
||||||
|
self.dropout = dropout
|
||||||
|
|
||||||
|
self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
|
||||||
|
self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
|
||||||
|
self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
|
||||||
|
self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)
|
||||||
|
|
||||||
|
self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
|
||||||
|
self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)
|
||||||
|
|
||||||
|
self.stable_softmax_2d = True
|
||||||
|
self.clamp_min_for_underflow = True
|
||||||
|
self.clamp_max_for_overflow = True
|
||||||
|
|
||||||
|
self._reset_parameters()
|
||||||
|
|
||||||
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||||
|
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
def _reset_parameters(self):
|
||||||
|
nn.init.xavier_uniform_(self.v_proj.weight)
|
||||||
|
self.v_proj.bias.data.fill_(0)
|
||||||
|
nn.init.xavier_uniform_(self.l_proj.weight)
|
||||||
|
self.l_proj.bias.data.fill_(0)
|
||||||
|
nn.init.xavier_uniform_(self.values_v_proj.weight)
|
||||||
|
self.values_v_proj.bias.data.fill_(0)
|
||||||
|
nn.init.xavier_uniform_(self.values_l_proj.weight)
|
||||||
|
self.values_l_proj.bias.data.fill_(0)
|
||||||
|
nn.init.xavier_uniform_(self.out_v_proj.weight)
|
||||||
|
self.out_v_proj.bias.data.fill_(0)
|
||||||
|
nn.init.xavier_uniform_(self.out_l_proj.weight)
|
||||||
|
self.out_l_proj.bias.data.fill_(0)
|
||||||
|
|
||||||
|
def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
v (_type_): bs, n_img, dim
|
||||||
|
l (_type_): bs, n_text, dim
|
||||||
|
attention_mask_v (_type_, optional): _description_. bs, n_img
|
||||||
|
attention_mask_l (_type_, optional): _description_. bs, n_text
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_type_: _description_
|
||||||
|
"""
|
||||||
|
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
bsz, tgt_len, _ = v.size()
|
||||||
|
|
||||||
|
query_states = self.v_proj(v) * self.scale
|
||||||
|
key_states = self._shape(self.l_proj(l), -1, bsz)
|
||||||
|
value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
|
||||||
|
value_l_states = self._shape(self.values_l_proj(l), -1, bsz)
|
||||||
|
|
||||||
|
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||||
|
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||||
|
key_states = key_states.view(*proj_shape)
|
||||||
|
value_v_states = value_v_states.view(*proj_shape)
|
||||||
|
value_l_states = value_l_states.view(*proj_shape)
|
||||||
|
|
||||||
|
src_len = key_states.size(1)
|
||||||
|
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) # bs*nhead, nimg, ntxt
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.stable_softmax_2d:
|
||||||
|
attn_weights = attn_weights - attn_weights.max()
|
||||||
|
|
||||||
|
if self.clamp_min_for_underflow:
|
||||||
|
attn_weights = torch.clamp(
|
||||||
|
attn_weights, min=-50000
|
||||||
|
) # Do not increase -50000, data type half has quite limited range
|
||||||
|
if self.clamp_max_for_overflow:
|
||||||
|
attn_weights = torch.clamp(
|
||||||
|
attn_weights, max=50000
|
||||||
|
) # Do not increase 50000, data type half has quite limited range
|
||||||
|
|
||||||
|
attn_weights_T = attn_weights.transpose(1, 2)
|
||||||
|
attn_weights_l = attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[0]
|
||||||
|
if self.clamp_min_for_underflow:
|
||||||
|
attn_weights_l = torch.clamp(
|
||||||
|
attn_weights_l, min=-50000
|
||||||
|
) # Do not increase -50000, data type half has quite limited range
|
||||||
|
if self.clamp_max_for_overflow:
|
||||||
|
attn_weights_l = torch.clamp(
|
||||||
|
attn_weights_l, max=50000
|
||||||
|
) # Do not increase 50000, data type half has quite limited range
|
||||||
|
|
||||||
|
# mask vison for language
|
||||||
|
if attention_mask_v is not None:
|
||||||
|
attention_mask_v = (
|
||||||
|
attention_mask_v[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
|
||||||
|
)
|
||||||
|
attn_weights_l.masked_fill_(attention_mask_v, float("-inf"))
|
||||||
|
|
||||||
|
attn_weights_l = attn_weights_l.softmax(dim=-1)
|
||||||
|
|
||||||
|
# mask language for vision
|
||||||
|
if attention_mask_l is not None:
|
||||||
|
attention_mask_l = (
|
||||||
|
attention_mask_l[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
|
||||||
|
)
|
||||||
|
attn_weights.masked_fill_(attention_mask_l, float("-inf"))
|
||||||
|
attn_weights_v = attn_weights.softmax(dim=-1)
|
||||||
|
|
||||||
|
attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
|
||||||
|
attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)
|
||||||
|
|
||||||
|
attn_output_v = torch.bmm(attn_probs_v, value_l_states)
|
||||||
|
attn_output_l = torch.bmm(attn_probs_l, value_v_states)
|
||||||
|
|
||||||
|
if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||||
|
attn_output_v = attn_output_v.transpose(1, 2)
|
||||||
|
attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)
|
||||||
|
|
||||||
|
attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
|
||||||
|
attn_output_l = attn_output_l.transpose(1, 2)
|
||||||
|
attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)
|
||||||
|
|
||||||
|
attn_output_v = self.out_v_proj(attn_output_v)
|
||||||
|
attn_output_l = self.out_l_proj(attn_output_l)
|
||||||
|
|
||||||
|
return attn_output_v, attn_output_l
|
||||||
|
|
||||||
|
|
||||||
|
# Bi-Direction MHA (text->image, image->text)
|
||||||
|
class BiAttentionBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
v_dim,
|
||||||
|
l_dim,
|
||||||
|
embed_dim,
|
||||||
|
num_heads,
|
||||||
|
dropout=0.1,
|
||||||
|
drop_path=0.0,
|
||||||
|
init_values=1e-4,
|
||||||
|
cfg=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Inputs:
|
||||||
|
embed_dim - Dimensionality of input and attention feature vectors
|
||||||
|
hidden_dim - Dimensionality of hidden layer in feed-forward network
|
||||||
|
(usually 2-4x larger than embed_dim)
|
||||||
|
num_heads - Number of heads to use in the Multi-Head Attention block
|
||||||
|
dropout - Amount of dropout to apply in the feed-forward network
|
||||||
|
"""
|
||||||
|
super(BiAttentionBlock, self).__init__()
|
||||||
|
|
||||||
|
# pre layer norm
|
||||||
|
self.layer_norm_v = nn.LayerNorm(v_dim)
|
||||||
|
self.layer_norm_l = nn.LayerNorm(l_dim)
|
||||||
|
self.attn = BiMultiHeadAttention(
|
||||||
|
v_dim=v_dim, l_dim=l_dim, embed_dim=embed_dim, num_heads=num_heads, dropout=dropout
|
||||||
|
)
|
||||||
|
|
||||||
|
# add layer scale for training stability
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||||
|
self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True)
|
||||||
|
self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True)
|
||||||
|
|
||||||
|
def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
|
||||||
|
v = self.layer_norm_v(v)
|
||||||
|
l = self.layer_norm_l(l)
|
||||||
|
delta_v, delta_l = self.attn(
|
||||||
|
v, l, attention_mask_v=attention_mask_v, attention_mask_l=attention_mask_l
|
||||||
|
)
|
||||||
|
# v, l = v + delta_v, l + delta_l
|
||||||
|
v = v + self.drop_path(self.gamma_v * delta_v)
|
||||||
|
l = l + self.drop_path(self.gamma_l * delta_l)
|
||||||
|
return v, l
|
||||||
|
|
||||||
|
# def forward(self, v:List[torch.Tensor], l, attention_mask_v=None, attention_mask_l=None)
|
@ -0,0 +1,394 @@
|
|||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Grounding DINO
|
||||||
|
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||||
|
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Conditional DETR model and criterion classes.
|
||||||
|
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Modified from DETR (https://github.com/facebookresearch/detr)
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
|
||||||
|
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
import copy
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
from torchvision.ops.boxes import nms
|
||||||
|
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
|
||||||
|
|
||||||
|
from groundingdino.util import box_ops, get_tokenlizer
|
||||||
|
from groundingdino.util.misc import (
|
||||||
|
NestedTensor,
|
||||||
|
accuracy,
|
||||||
|
get_world_size,
|
||||||
|
interpolate,
|
||||||
|
inverse_sigmoid,
|
||||||
|
is_dist_avail_and_initialized,
|
||||||
|
nested_tensor_from_tensor_list,
|
||||||
|
)
|
||||||
|
from groundingdino.util.utils import get_phrases_from_posmap
|
||||||
|
from groundingdino.util.visualizer import COCOVisualizer
|
||||||
|
from groundingdino.util.vl_utils import create_positive_map_from_span
|
||||||
|
|
||||||
|
from ..registry import MODULE_BUILD_FUNCS
|
||||||
|
from .backbone import build_backbone
|
||||||
|
from .bertwarper import (
|
||||||
|
BertModelWarper,
|
||||||
|
generate_masks_with_special_tokens,
|
||||||
|
generate_masks_with_special_tokens_and_transfer_map,
|
||||||
|
)
|
||||||
|
from .transformer import build_transformer
|
||||||
|
from .utils import MLP, ContrastiveEmbed, sigmoid_focal_loss
|
||||||
|
|
||||||
|
|
||||||
|
class GroundingDINO(nn.Module):
|
||||||
|
"""This is the Cross-Attention Detector module that performs object detection"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
backbone,
|
||||||
|
transformer,
|
||||||
|
num_queries,
|
||||||
|
aux_loss=False,
|
||||||
|
iter_update=False,
|
||||||
|
query_dim=2,
|
||||||
|
num_feature_levels=1,
|
||||||
|
nheads=8,
|
||||||
|
# two stage
|
||||||
|
two_stage_type="no", # ['no', 'standard']
|
||||||
|
dec_pred_bbox_embed_share=True,
|
||||||
|
two_stage_class_embed_share=True,
|
||||||
|
two_stage_bbox_embed_share=True,
|
||||||
|
num_patterns=0,
|
||||||
|
dn_number=100,
|
||||||
|
dn_box_noise_scale=0.4,
|
||||||
|
dn_label_noise_ratio=0.5,
|
||||||
|
dn_labelbook_size=100,
|
||||||
|
text_encoder_type="bert-base-uncased",
|
||||||
|
sub_sentence_present=True,
|
||||||
|
max_text_len=256,
|
||||||
|
):
|
||||||
|
"""Initializes the model.
|
||||||
|
Parameters:
|
||||||
|
backbone: torch module of the backbone to be used. See backbone.py
|
||||||
|
transformer: torch module of the transformer architecture. See transformer.py
|
||||||
|
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
||||||
|
Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
|
||||||
|
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.num_queries = num_queries
|
||||||
|
self.transformer = transformer
|
||||||
|
self.hidden_dim = hidden_dim = transformer.d_model
|
||||||
|
self.num_feature_levels = num_feature_levels
|
||||||
|
self.nheads = nheads
|
||||||
|
self.max_text_len = 256
|
||||||
|
self.sub_sentence_present = sub_sentence_present
|
||||||
|
|
||||||
|
# setting query dim
|
||||||
|
self.query_dim = query_dim
|
||||||
|
assert query_dim == 4
|
||||||
|
|
||||||
|
# for dn training
|
||||||
|
self.num_patterns = num_patterns
|
||||||
|
self.dn_number = dn_number
|
||||||
|
self.dn_box_noise_scale = dn_box_noise_scale
|
||||||
|
self.dn_label_noise_ratio = dn_label_noise_ratio
|
||||||
|
self.dn_labelbook_size = dn_labelbook_size
|
||||||
|
|
||||||
|
# bert
|
||||||
|
self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type)
|
||||||
|
self.bert = get_tokenlizer.get_pretrained_language_model(text_encoder_type)
|
||||||
|
self.bert.pooler.dense.weight.requires_grad_(False)
|
||||||
|
self.bert.pooler.dense.bias.requires_grad_(False)
|
||||||
|
self.bert = BertModelWarper(bert_model=self.bert)
|
||||||
|
|
||||||
|
self.feat_map = nn.Linear(self.bert.config.hidden_size, self.hidden_dim, bias=True)
|
||||||
|
nn.init.constant_(self.feat_map.bias.data, 0)
|
||||||
|
nn.init.xavier_uniform_(self.feat_map.weight.data)
|
||||||
|
# freeze
|
||||||
|
|
||||||
|
# special tokens
|
||||||
|
self.specical_tokens = self.tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])
|
||||||
|
|
||||||
|
# prepare input projection layers
|
||||||
|
if num_feature_levels > 1:
|
||||||
|
num_backbone_outs = len(backbone.num_channels)
|
||||||
|
input_proj_list = []
|
||||||
|
for _ in range(num_backbone_outs):
|
||||||
|
in_channels = backbone.num_channels[_]
|
||||||
|
input_proj_list.append(
|
||||||
|
nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
|
||||||
|
nn.GroupNorm(32, hidden_dim),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for _ in range(num_feature_levels - num_backbone_outs):
|
||||||
|
input_proj_list.append(
|
||||||
|
nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
|
||||||
|
nn.GroupNorm(32, hidden_dim),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
in_channels = hidden_dim
|
||||||
|
self.input_proj = nn.ModuleList(input_proj_list)
|
||||||
|
else:
|
||||||
|
assert two_stage_type == "no", "two_stage_type should be no if num_feature_levels=1 !!!"
|
||||||
|
self.input_proj = nn.ModuleList(
|
||||||
|
[
|
||||||
|
nn.Sequential(
|
||||||
|
nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
|
||||||
|
nn.GroupNorm(32, hidden_dim),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.backbone = backbone
|
||||||
|
self.aux_loss = aux_loss
|
||||||
|
self.box_pred_damping = box_pred_damping = None
|
||||||
|
|
||||||
|
self.iter_update = iter_update
|
||||||
|
assert iter_update, "Why not iter_update?"
|
||||||
|
|
||||||
|
# prepare pred layers
|
||||||
|
self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
|
||||||
|
# prepare class & box embed
|
||||||
|
_class_embed = ContrastiveEmbed()
|
||||||
|
|
||||||
|
_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
|
||||||
|
nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
|
||||||
|
nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
|
||||||
|
|
||||||
|
if dec_pred_bbox_embed_share:
|
||||||
|
box_embed_layerlist = [_bbox_embed for i in range(transformer.num_decoder_layers)]
|
||||||
|
else:
|
||||||
|
box_embed_layerlist = [
|
||||||
|
copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers)
|
||||||
|
]
|
||||||
|
class_embed_layerlist = [_class_embed for i in range(transformer.num_decoder_layers)]
|
||||||
|
self.bbox_embed = nn.ModuleList(box_embed_layerlist)
|
||||||
|
self.class_embed = nn.ModuleList(class_embed_layerlist)
|
||||||
|
self.transformer.decoder.bbox_embed = self.bbox_embed
|
||||||
|
self.transformer.decoder.class_embed = self.class_embed
|
||||||
|
|
||||||
|
# two stage
|
||||||
|
self.two_stage_type = two_stage_type
|
||||||
|
assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
|
||||||
|
two_stage_type
|
||||||
|
)
|
||||||
|
if two_stage_type != "no":
|
||||||
|
if two_stage_bbox_embed_share:
|
||||||
|
assert dec_pred_bbox_embed_share
|
||||||
|
self.transformer.enc_out_bbox_embed = _bbox_embed
|
||||||
|
else:
|
||||||
|
self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
|
||||||
|
|
||||||
|
if two_stage_class_embed_share:
|
||||||
|
assert dec_pred_bbox_embed_share
|
||||||
|
self.transformer.enc_out_class_embed = _class_embed
|
||||||
|
else:
|
||||||
|
self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)
|
||||||
|
|
||||||
|
self.refpoint_embed = None
|
||||||
|
|
||||||
|
self._reset_parameters()
|
||||||
|
|
||||||
|
def _reset_parameters(self):
|
||||||
|
# init input_proj
|
||||||
|
for proj in self.input_proj:
|
||||||
|
nn.init.xavier_uniform_(proj[0].weight, gain=1)
|
||||||
|
nn.init.constant_(proj[0].bias, 0)
|
||||||
|
|
||||||
|
def init_ref_points(self, use_num_queries):
|
||||||
|
self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)
|
||||||
|
|
||||||
|
def forward(self, samples: NestedTensor, targets: List = None, **kw):
|
||||||
|
"""The forward expects a NestedTensor, which consists of:
|
||||||
|
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
|
||||||
|
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
|
||||||
|
|
||||||
|
It returns a dict with the following elements:
|
||||||
|
- "pred_logits": the classification logits (including no-object) for all queries.
|
||||||
|
Shape= [batch_size x num_queries x num_classes]
|
||||||
|
- "pred_boxes": The normalized boxes coordinates for all queries, represented as
|
||||||
|
(center_x, center_y, width, height). These values are normalized in [0, 1],
|
||||||
|
relative to the size of each individual image (disregarding possible padding).
|
||||||
|
See PostProcess for information on how to retrieve the unnormalized bounding box.
|
||||||
|
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
|
||||||
|
dictionnaries containing the two above keys for each decoder layer.
|
||||||
|
"""
|
||||||
|
if targets is None:
|
||||||
|
captions = kw["captions"]
|
||||||
|
else:
|
||||||
|
captions = [t["caption"] for t in targets]
|
||||||
|
|
||||||
|
# encoder texts
|
||||||
|
tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(
|
||||||
|
samples.device
|
||||||
|
)
|
||||||
|
(
|
||||||
|
text_self_attention_masks,
|
||||||
|
position_ids,
|
||||||
|
cate_to_token_mask_list,
|
||||||
|
) = generate_masks_with_special_tokens_and_transfer_map(
|
||||||
|
tokenized, self.specical_tokens, self.tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
|
if text_self_attention_masks.shape[1] > self.max_text_len:
|
||||||
|
text_self_attention_masks = text_self_attention_masks[
|
||||||
|
:, : self.max_text_len, : self.max_text_len
|
||||||
|
]
|
||||||
|
position_ids = position_ids[:, : self.max_text_len]
|
||||||
|
tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
|
||||||
|
tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
|
||||||
|
tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]
|
||||||
|
|
||||||
|
# extract text embeddings
|
||||||
|
if self.sub_sentence_present:
|
||||||
|
tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
|
||||||
|
tokenized_for_encoder["attention_mask"] = text_self_attention_masks
|
||||||
|
tokenized_for_encoder["position_ids"] = position_ids
|
||||||
|
else:
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
tokenized_for_encoder = tokenized
|
||||||
|
|
||||||
|
bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768
|
||||||
|
|
||||||
|
encoded_text = self.feat_map(bert_output["last_hidden_state"]) # bs, 195, d_model
|
||||||
|
text_token_mask = tokenized.attention_mask.bool() # bs, 195
|
||||||
|
# text_token_mask: True for nomask, False for mask
|
||||||
|
# text_self_attention_masks: True for nomask, False for mask
|
||||||
|
|
||||||
|
if encoded_text.shape[1] > self.max_text_len:
|
||||||
|
encoded_text = encoded_text[:, : self.max_text_len, :]
|
||||||
|
text_token_mask = text_token_mask[:, : self.max_text_len]
|
||||||
|
position_ids = position_ids[:, : self.max_text_len]
|
||||||
|
text_self_attention_masks = text_self_attention_masks[
|
||||||
|
:, : self.max_text_len, : self.max_text_len
|
||||||
|
]
|
||||||
|
|
||||||
|
text_dict = {
|
||||||
|
"encoded_text": encoded_text, # bs, 195, d_model
|
||||||
|
"text_token_mask": text_token_mask, # bs, 195
|
||||||
|
"position_ids": position_ids, # bs, 195
|
||||||
|
"text_self_attention_masks": text_self_attention_masks, # bs, 195,195
|
||||||
|
}
|
||||||
|
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
|
||||||
|
if isinstance(samples, (list, torch.Tensor)):
|
||||||
|
samples = nested_tensor_from_tensor_list(samples)
|
||||||
|
features, poss = self.backbone(samples)
|
||||||
|
|
||||||
|
srcs = []
|
||||||
|
masks = []
|
||||||
|
for l, feat in enumerate(features):
|
||||||
|
src, mask = feat.decompose()
|
||||||
|
srcs.append(self.input_proj[l](src))
|
||||||
|
masks.append(mask)
|
||||||
|
assert mask is not None
|
||||||
|
if self.num_feature_levels > len(srcs):
|
||||||
|
_len_srcs = len(srcs)
|
||||||
|
for l in range(_len_srcs, self.num_feature_levels):
|
||||||
|
if l == _len_srcs:
|
||||||
|
src = self.input_proj[l](features[-1].tensors)
|
||||||
|
else:
|
||||||
|
src = self.input_proj[l](srcs[-1])
|
||||||
|
m = samples.mask
|
||||||
|
mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
|
||||||
|
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
|
||||||
|
srcs.append(src)
|
||||||
|
masks.append(mask)
|
||||||
|
poss.append(pos_l)
|
||||||
|
|
||||||
|
input_query_bbox = input_query_label = attn_mask = dn_meta = None
|
||||||
|
hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
|
||||||
|
srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, text_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# deformable-detr-like anchor update
|
||||||
|
outputs_coord_list = []
|
||||||
|
for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
|
||||||
|
zip(reference[:-1], self.bbox_embed, hs)
|
||||||
|
):
|
||||||
|
layer_delta_unsig = layer_bbox_embed(layer_hs)
|
||||||
|
layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
|
||||||
|
layer_outputs_unsig = layer_outputs_unsig.sigmoid()
|
||||||
|
outputs_coord_list.append(layer_outputs_unsig)
|
||||||
|
outputs_coord_list = torch.stack(outputs_coord_list)
|
||||||
|
|
||||||
|
# output
|
||||||
|
outputs_class = torch.stack(
|
||||||
|
[
|
||||||
|
layer_cls_embed(layer_hs, text_dict)
|
||||||
|
for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord_list[-1]}
|
||||||
|
|
||||||
|
# # for intermediate outputs
|
||||||
|
# if self.aux_loss:
|
||||||
|
# out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord_list)
|
||||||
|
|
||||||
|
# # for encoder output
|
||||||
|
# if hs_enc is not None:
|
||||||
|
# # prepare intermediate outputs
|
||||||
|
# interm_coord = ref_enc[-1]
|
||||||
|
# interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
|
||||||
|
# out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
|
||||||
|
# out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
@torch.jit.unused
|
||||||
|
def _set_aux_loss(self, outputs_class, outputs_coord):
|
||||||
|
# this is a workaround to make torchscript happy, as torchscript
|
||||||
|
# doesn't support dictionary with non-homogeneous values, such
|
||||||
|
# as a dict having both a Tensor and a list.
|
||||||
|
return [
|
||||||
|
{"pred_logits": a, "pred_boxes": b}
|
||||||
|
for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@MODULE_BUILD_FUNCS.registe_with_name(module_name="groundingdino")
|
||||||
|
def build_groundingdino(args):
|
||||||
|
|
||||||
|
backbone = build_backbone(args)
|
||||||
|
transformer = build_transformer(args)
|
||||||
|
|
||||||
|
dn_labelbook_size = args.dn_labelbook_size
|
||||||
|
dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share
|
||||||
|
sub_sentence_present = args.sub_sentence_present
|
||||||
|
|
||||||
|
model = GroundingDINO(
|
||||||
|
backbone,
|
||||||
|
transformer,
|
||||||
|
num_queries=args.num_queries,
|
||||||
|
aux_loss=True,
|
||||||
|
iter_update=True,
|
||||||
|
query_dim=4,
|
||||||
|
num_feature_levels=args.num_feature_levels,
|
||||||
|
nheads=args.nheads,
|
||||||
|
dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
|
||||||
|
two_stage_type=args.two_stage_type,
|
||||||
|
two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
|
||||||
|
two_stage_class_embed_share=args.two_stage_class_embed_share,
|
||||||
|
num_patterns=args.num_patterns,
|
||||||
|
dn_number=0,
|
||||||
|
dn_box_noise_scale=args.dn_box_noise_scale,
|
||||||
|
dn_label_noise_ratio=args.dn_label_noise_ratio,
|
||||||
|
dn_labelbook_size=dn_labelbook_size,
|
||||||
|
text_encoder_type=args.text_encoder_type,
|
||||||
|
sub_sentence_present=sub_sentence_present,
|
||||||
|
max_text_len=args.max_text_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
return model
|
@ -0,0 +1,413 @@
|
|||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Grounding DINO
|
||||||
|
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||||
|
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Deformable DETR
|
||||||
|
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------------------------------
|
||||||
|
# Modified from:
|
||||||
|
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/functions/ms_deform_attn_func.py
|
||||||
|
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
|
||||||
|
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/multi_scale_deform_attn.py
|
||||||
|
# ------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import math
|
||||||
|
import warnings
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.autograd import Function
|
||||||
|
from torch.autograd.function import once_differentiable
|
||||||
|
from torch.nn.init import constant_, xavier_uniform_
|
||||||
|
|
||||||
|
try:
|
||||||
|
from groundingdino import _C
|
||||||
|
except:
|
||||||
|
warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only!")
|
||||||
|
|
||||||
|
|
||||||
|
# helpers
|
||||||
|
def _is_power_of_2(n):
|
||||||
|
if (not isinstance(n, int)) or (n < 0):
|
||||||
|
raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
|
||||||
|
return (n & (n - 1) == 0) and n != 0
|
||||||
|
|
||||||
|
|
||||||
|
class MultiScaleDeformableAttnFunction(Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(
|
||||||
|
ctx,
|
||||||
|
value,
|
||||||
|
value_spatial_shapes,
|
||||||
|
value_level_start_index,
|
||||||
|
sampling_locations,
|
||||||
|
attention_weights,
|
||||||
|
im2col_step,
|
||||||
|
):
|
||||||
|
ctx.im2col_step = im2col_step
|
||||||
|
output = _C.ms_deform_attn_forward(
|
||||||
|
value,
|
||||||
|
value_spatial_shapes,
|
||||||
|
value_level_start_index,
|
||||||
|
sampling_locations,
|
||||||
|
attention_weights,
|
||||||
|
ctx.im2col_step,
|
||||||
|
)
|
||||||
|
ctx.save_for_backward(
|
||||||
|
value,
|
||||||
|
value_spatial_shapes,
|
||||||
|
value_level_start_index,
|
||||||
|
sampling_locations,
|
||||||
|
attention_weights,
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@once_differentiable
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
(
|
||||||
|
value,
|
||||||
|
value_spatial_shapes,
|
||||||
|
value_level_start_index,
|
||||||
|
sampling_locations,
|
||||||
|
attention_weights,
|
||||||
|
) = ctx.saved_tensors
|
||||||
|
grad_value, grad_sampling_loc, grad_attn_weight = _C.ms_deform_attn_backward(
|
||||||
|
value,
|
||||||
|
value_spatial_shapes,
|
||||||
|
value_level_start_index,
|
||||||
|
sampling_locations,
|
||||||
|
attention_weights,
|
||||||
|
grad_output,
|
||||||
|
ctx.im2col_step,
|
||||||
|
)
|
||||||
|
|
||||||
|
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
|
||||||
|
|
||||||
|
|
||||||
|
def multi_scale_deformable_attn_pytorch(
|
||||||
|
value: torch.Tensor,
|
||||||
|
value_spatial_shapes: torch.Tensor,
|
||||||
|
sampling_locations: torch.Tensor,
|
||||||
|
attention_weights: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
bs, _, num_heads, embed_dims = value.shape
|
||||||
|
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
|
||||||
|
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
|
||||||
|
sampling_grids = 2 * sampling_locations - 1
|
||||||
|
sampling_value_list = []
|
||||||
|
for level, (H_, W_) in enumerate(value_spatial_shapes):
|
||||||
|
# bs, H_*W_, num_heads, embed_dims ->
|
||||||
|
# bs, H_*W_, num_heads*embed_dims ->
|
||||||
|
# bs, num_heads*embed_dims, H_*W_ ->
|
||||||
|
# bs*num_heads, embed_dims, H_, W_
|
||||||
|
value_l_ = (
|
||||||
|
value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
|
||||||
|
)
|
||||||
|
# bs, num_queries, num_heads, num_points, 2 ->
|
||||||
|
# bs, num_heads, num_queries, num_points, 2 ->
|
||||||
|
# bs*num_heads, num_queries, num_points, 2
|
||||||
|
sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
|
||||||
|
# bs*num_heads, embed_dims, num_queries, num_points
|
||||||
|
sampling_value_l_ = F.grid_sample(
|
||||||
|
value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
|
||||||
|
)
|
||||||
|
sampling_value_list.append(sampling_value_l_)
|
||||||
|
# (bs, num_queries, num_heads, num_levels, num_points) ->
|
||||||
|
# (bs, num_heads, num_queries, num_levels, num_points) ->
|
||||||
|
# (bs, num_heads, 1, num_queries, num_levels*num_points)
|
||||||
|
attention_weights = attention_weights.transpose(1, 2).reshape(
|
||||||
|
bs * num_heads, 1, num_queries, num_levels * num_points
|
||||||
|
)
|
||||||
|
output = (
|
||||||
|
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
|
||||||
|
.sum(-1)
|
||||||
|
.view(bs, num_heads * embed_dims, num_queries)
|
||||||
|
)
|
||||||
|
return output.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
class MultiScaleDeformableAttention(nn.Module):
|
||||||
|
"""Multi-Scale Deformable Attention Module used in Deformable-DETR
|
||||||
|
|
||||||
|
`Deformable DETR: Deformable Transformers for End-to-End Object Detection.
|
||||||
|
<https://arxiv.org/pdf/2010.04159.pdf>`_.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embed_dim (int): The embedding dimension of Attention. Default: 256.
|
||||||
|
num_heads (int): The number of attention heads. Default: 8.
|
||||||
|
num_levels (int): The number of feature map used in Attention. Default: 4.
|
||||||
|
num_points (int): The number of sampling points for each query
|
||||||
|
in each head. Default: 4.
|
||||||
|
img2col_steps (int): The step used in image_to_column. Defualt: 64.
|
||||||
|
dropout (float): Dropout layer used in output. Default: 0.1.
|
||||||
|
batch_first (bool): if ``True``, then the input and output tensor will be
|
||||||
|
provided as `(bs, n, embed_dim)`. Default: False. `(n, bs, embed_dim)`
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim: int = 256,
|
||||||
|
num_heads: int = 8,
|
||||||
|
num_levels: int = 4,
|
||||||
|
num_points: int = 4,
|
||||||
|
img2col_step: int = 64,
|
||||||
|
batch_first: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if embed_dim % num_heads != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"embed_dim must be divisible by num_heads, but got {} and {}".format(
|
||||||
|
embed_dim, num_heads
|
||||||
|
)
|
||||||
|
)
|
||||||
|
head_dim = embed_dim // num_heads
|
||||||
|
|
||||||
|
self.batch_first = batch_first
|
||||||
|
|
||||||
|
if not _is_power_of_2(head_dim):
|
||||||
|
warnings.warn(
|
||||||
|
"""
|
||||||
|
You'd better set d_model in MSDeformAttn to make sure that
|
||||||
|
each dim of the attention head a power of 2, which is more efficient.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
self.im2col_step = img2col_step
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_levels = num_levels
|
||||||
|
self.num_points = num_points
|
||||||
|
self.sampling_offsets = nn.Linear(embed_dim, num_heads * num_levels * num_points * 2)
|
||||||
|
self.attention_weights = nn.Linear(embed_dim, num_heads * num_levels * num_points)
|
||||||
|
self.value_proj = nn.Linear(embed_dim, embed_dim)
|
||||||
|
self.output_proj = nn.Linear(embed_dim, embed_dim)
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
def _reset_parameters(self):
|
||||||
|
return self.init_weights()
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
"""
|
||||||
|
Default initialization for Parameters of Module.
|
||||||
|
"""
|
||||||
|
constant_(self.sampling_offsets.weight.data, 0.0)
|
||||||
|
thetas = torch.arange(self.num_heads, dtype=torch.float32) * (
|
||||||
|
2.0 * math.pi / self.num_heads
|
||||||
|
)
|
||||||
|
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
|
||||||
|
grid_init = (
|
||||||
|
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
|
||||||
|
.view(self.num_heads, 1, 1, 2)
|
||||||
|
.repeat(1, self.num_levels, self.num_points, 1)
|
||||||
|
)
|
||||||
|
for i in range(self.num_points):
|
||||||
|
grid_init[:, :, i, :] *= i + 1
|
||||||
|
with torch.no_grad():
|
||||||
|
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
|
||||||
|
constant_(self.attention_weights.weight.data, 0.0)
|
||||||
|
constant_(self.attention_weights.bias.data, 0.0)
|
||||||
|
xavier_uniform_(self.value_proj.weight.data)
|
||||||
|
constant_(self.value_proj.bias.data, 0.0)
|
||||||
|
xavier_uniform_(self.output_proj.weight.data)
|
||||||
|
constant_(self.output_proj.bias.data, 0.0)
|
||||||
|
|
||||||
|
def freeze_sampling_offsets(self):
|
||||||
|
print("Freeze sampling offsets")
|
||||||
|
self.sampling_offsets.weight.requires_grad = False
|
||||||
|
self.sampling_offsets.bias.requires_grad = False
|
||||||
|
|
||||||
|
def freeze_attention_weights(self):
|
||||||
|
print("Freeze attention weights")
|
||||||
|
self.attention_weights.weight.requires_grad = False
|
||||||
|
self.attention_weights.bias.requires_grad = False
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: Optional[torch.Tensor] = None,
|
||||||
|
value: Optional[torch.Tensor] = None,
|
||||||
|
query_pos: Optional[torch.Tensor] = None,
|
||||||
|
key_padding_mask: Optional[torch.Tensor] = None,
|
||||||
|
reference_points: Optional[torch.Tensor] = None,
|
||||||
|
spatial_shapes: Optional[torch.Tensor] = None,
|
||||||
|
level_start_index: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
"""Forward Function of MultiScaleDeformableAttention
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (torch.Tensor): Query embeddings with shape
|
||||||
|
`(num_query, bs, embed_dim)`
|
||||||
|
key (torch.Tensor): Key embeddings with shape
|
||||||
|
`(num_key, bs, embed_dim)`
|
||||||
|
value (torch.Tensor): Value embeddings with shape
|
||||||
|
`(num_key, bs, embed_dim)`
|
||||||
|
query_pos (torch.Tensor): The position embedding for `query`. Default: None.
|
||||||
|
key_padding_mask (torch.Tensor): ByteTensor for `query`, with shape `(bs, num_key)`,
|
||||||
|
indicating which elements within `key` to be ignored in attention.
|
||||||
|
reference_points (torch.Tensor): The normalized reference points
|
||||||
|
with shape `(bs, num_query, num_levels, 2)`,
|
||||||
|
all elements is range in [0, 1], top-left (0, 0),
|
||||||
|
bottom-right (1, 1), including padding are.
|
||||||
|
or `(N, Length_{query}, num_levels, 4)`, add additional
|
||||||
|
two dimensions `(h, w)` to form reference boxes.
|
||||||
|
spatial_shapes (torch.Tensor): Spatial shape of features in different levels.
|
||||||
|
With shape `(num_levels, 2)`, last dimension represents `(h, w)`.
|
||||||
|
level_start_index (torch.Tensor): The start index of each level. A tensor with
|
||||||
|
shape `(num_levels, )` which can be represented as
|
||||||
|
`[0, h_0 * w_0, h_0 * w_0 + h_1 * w_1, ...]`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: forward results with shape `(num_query, bs, embed_dim)`
|
||||||
|
"""
|
||||||
|
|
||||||
|
if value is None:
|
||||||
|
value = query
|
||||||
|
|
||||||
|
if query_pos is not None:
|
||||||
|
query = query + query_pos
|
||||||
|
|
||||||
|
if not self.batch_first:
|
||||||
|
# change to (bs, num_query ,embed_dims)
|
||||||
|
query = query.permute(1, 0, 2)
|
||||||
|
value = value.permute(1, 0, 2)
|
||||||
|
|
||||||
|
bs, num_query, _ = query.shape
|
||||||
|
bs, num_value, _ = value.shape
|
||||||
|
|
||||||
|
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
|
||||||
|
|
||||||
|
value = self.value_proj(value)
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
value = value.masked_fill(key_padding_mask[..., None], float(0))
|
||||||
|
value = value.view(bs, num_value, self.num_heads, -1)
|
||||||
|
sampling_offsets = self.sampling_offsets(query).view(
|
||||||
|
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2
|
||||||
|
)
|
||||||
|
attention_weights = self.attention_weights(query).view(
|
||||||
|
bs, num_query, self.num_heads, self.num_levels * self.num_points
|
||||||
|
)
|
||||||
|
attention_weights = attention_weights.softmax(-1)
|
||||||
|
attention_weights = attention_weights.view(
|
||||||
|
bs,
|
||||||
|
num_query,
|
||||||
|
self.num_heads,
|
||||||
|
self.num_levels,
|
||||||
|
self.num_points,
|
||||||
|
)
|
||||||
|
|
||||||
|
# bs, num_query, num_heads, num_levels, num_points, 2
|
||||||
|
if reference_points.shape[-1] == 2:
|
||||||
|
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
|
||||||
|
sampling_locations = (
|
||||||
|
reference_points[:, :, None, :, None, :]
|
||||||
|
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
|
||||||
|
)
|
||||||
|
elif reference_points.shape[-1] == 4:
|
||||||
|
sampling_locations = (
|
||||||
|
reference_points[:, :, None, :, None, :2]
|
||||||
|
+ sampling_offsets
|
||||||
|
/ self.num_points
|
||||||
|
* reference_points[:, :, None, :, None, 2:]
|
||||||
|
* 0.5
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Last dim of reference_points must be 2 or 4, but get {} instead.".format(
|
||||||
|
reference_points.shape[-1]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if torch.cuda.is_available() and value.is_cuda:
|
||||||
|
halffloat = False
|
||||||
|
if value.dtype == torch.float16:
|
||||||
|
halffloat = True
|
||||||
|
value = value.float()
|
||||||
|
sampling_locations = sampling_locations.float()
|
||||||
|
attention_weights = attention_weights.float()
|
||||||
|
|
||||||
|
output = MultiScaleDeformableAttnFunction.apply(
|
||||||
|
value,
|
||||||
|
spatial_shapes,
|
||||||
|
level_start_index,
|
||||||
|
sampling_locations,
|
||||||
|
attention_weights,
|
||||||
|
self.im2col_step,
|
||||||
|
)
|
||||||
|
|
||||||
|
if halffloat:
|
||||||
|
output = output.half()
|
||||||
|
else:
|
||||||
|
output = multi_scale_deformable_attn_pytorch(
|
||||||
|
value, spatial_shapes, sampling_locations, attention_weights
|
||||||
|
)
|
||||||
|
|
||||||
|
output = self.output_proj(output)
|
||||||
|
|
||||||
|
if not self.batch_first:
|
||||||
|
output = output.permute(1, 0, 2)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def create_dummy_class(klass, dependency, message=""):
|
||||||
|
"""
|
||||||
|
When a dependency of a class is not available, create a dummy class which throws ImportError
|
||||||
|
when used.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
klass (str): name of the class.
|
||||||
|
dependency (str): name of the dependency.
|
||||||
|
message: extra message to print
|
||||||
|
Returns:
|
||||||
|
class: a class object
|
||||||
|
"""
|
||||||
|
err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, klass)
|
||||||
|
if message:
|
||||||
|
err = err + " " + message
|
||||||
|
|
||||||
|
class _DummyMetaClass(type):
|
||||||
|
# throw error on class attribute access
|
||||||
|
def __getattr__(_, __): # noqa: B902
|
||||||
|
raise ImportError(err)
|
||||||
|
|
||||||
|
class _Dummy(object, metaclass=_DummyMetaClass):
|
||||||
|
# throw error on constructor
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
raise ImportError(err)
|
||||||
|
|
||||||
|
return _Dummy
|
||||||
|
|
||||||
|
|
||||||
|
def create_dummy_func(func, dependency, message=""):
|
||||||
|
"""
|
||||||
|
When a dependency of a function is not available, create a dummy function which throws
|
||||||
|
ImportError when used.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func (str): name of the function.
|
||||||
|
dependency (str or list[str]): name(s) of the dependency.
|
||||||
|
message: extra message to print
|
||||||
|
Returns:
|
||||||
|
function: a function object
|
||||||
|
"""
|
||||||
|
err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, func)
|
||||||
|
if message:
|
||||||
|
err = err + " " + message
|
||||||
|
|
||||||
|
if isinstance(dependency, (list, tuple)):
|
||||||
|
dependency = ",".join(dependency)
|
||||||
|
|
||||||
|
def _dummy(*args, **kwargs):
|
||||||
|
raise ImportError(err)
|
||||||
|
|
||||||
|
return _dummy
|
@ -0,0 +1,959 @@
|
|||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Grounding DINO
|
||||||
|
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||||
|
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# DINO
|
||||||
|
# Copyright (c) 2022 IDEA. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Conditional DETR Transformer class.
|
||||||
|
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Modified from DETR (https://github.com/facebookresearch/detr)
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint as checkpoint
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
from groundingdino.util.misc import inverse_sigmoid
|
||||||
|
|
||||||
|
from .fuse_modules import BiAttentionBlock
|
||||||
|
from .ms_deform_attn import MultiScaleDeformableAttention as MSDeformAttn
|
||||||
|
from .transformer_vanilla import TransformerEncoderLayer
|
||||||
|
from .utils import (
|
||||||
|
MLP,
|
||||||
|
_get_activation_fn,
|
||||||
|
_get_clones,
|
||||||
|
gen_encoder_output_proposals,
|
||||||
|
gen_sineembed_for_position,
|
||||||
|
get_sine_pos_embed,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model=256,
|
||||||
|
nhead=8,
|
||||||
|
num_queries=300,
|
||||||
|
num_encoder_layers=6,
|
||||||
|
num_unicoder_layers=0,
|
||||||
|
num_decoder_layers=6,
|
||||||
|
dim_feedforward=2048,
|
||||||
|
dropout=0.0,
|
||||||
|
activation="relu",
|
||||||
|
normalize_before=False,
|
||||||
|
return_intermediate_dec=False,
|
||||||
|
query_dim=4,
|
||||||
|
num_patterns=0,
|
||||||
|
# for deformable encoder
|
||||||
|
num_feature_levels=1,
|
||||||
|
enc_n_points=4,
|
||||||
|
dec_n_points=4,
|
||||||
|
# init query
|
||||||
|
learnable_tgt_init=False,
|
||||||
|
# two stage
|
||||||
|
two_stage_type="no", # ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1']
|
||||||
|
embed_init_tgt=False,
|
||||||
|
# for text
|
||||||
|
use_text_enhancer=False,
|
||||||
|
use_fusion_layer=False,
|
||||||
|
use_checkpoint=False,
|
||||||
|
use_transformer_ckpt=False,
|
||||||
|
use_text_cross_attention=False,
|
||||||
|
text_dropout=0.1,
|
||||||
|
fusion_dropout=0.1,
|
||||||
|
fusion_droppath=0.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_feature_levels = num_feature_levels
|
||||||
|
self.num_encoder_layers = num_encoder_layers
|
||||||
|
self.num_unicoder_layers = num_unicoder_layers
|
||||||
|
self.num_decoder_layers = num_decoder_layers
|
||||||
|
self.num_queries = num_queries
|
||||||
|
assert query_dim == 4
|
||||||
|
|
||||||
|
# choose encoder layer type
|
||||||
|
encoder_layer = DeformableTransformerEncoderLayer(
|
||||||
|
d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_text_enhancer:
|
||||||
|
text_enhance_layer = TransformerEncoderLayer(
|
||||||
|
d_model=d_model,
|
||||||
|
nhead=nhead // 2,
|
||||||
|
dim_feedforward=dim_feedforward // 2,
|
||||||
|
dropout=text_dropout,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
text_enhance_layer = None
|
||||||
|
|
||||||
|
if use_fusion_layer:
|
||||||
|
feature_fusion_layer = BiAttentionBlock(
|
||||||
|
v_dim=d_model,
|
||||||
|
l_dim=d_model,
|
||||||
|
embed_dim=dim_feedforward // 2,
|
||||||
|
num_heads=nhead // 2,
|
||||||
|
dropout=fusion_dropout,
|
||||||
|
drop_path=fusion_droppath,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
feature_fusion_layer = None
|
||||||
|
|
||||||
|
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
||||||
|
assert encoder_norm is None
|
||||||
|
self.encoder = TransformerEncoder(
|
||||||
|
encoder_layer,
|
||||||
|
num_encoder_layers,
|
||||||
|
d_model=d_model,
|
||||||
|
num_queries=num_queries,
|
||||||
|
text_enhance_layer=text_enhance_layer,
|
||||||
|
feature_fusion_layer=feature_fusion_layer,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
use_transformer_ckpt=use_transformer_ckpt,
|
||||||
|
)
|
||||||
|
|
||||||
|
# choose decoder layer type
|
||||||
|
decoder_layer = DeformableTransformerDecoderLayer(
|
||||||
|
d_model,
|
||||||
|
dim_feedforward,
|
||||||
|
dropout,
|
||||||
|
activation,
|
||||||
|
num_feature_levels,
|
||||||
|
nhead,
|
||||||
|
dec_n_points,
|
||||||
|
use_text_cross_attention=use_text_cross_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_norm = nn.LayerNorm(d_model)
|
||||||
|
self.decoder = TransformerDecoder(
|
||||||
|
decoder_layer,
|
||||||
|
num_decoder_layers,
|
||||||
|
decoder_norm,
|
||||||
|
return_intermediate=return_intermediate_dec,
|
||||||
|
d_model=d_model,
|
||||||
|
query_dim=query_dim,
|
||||||
|
num_feature_levels=num_feature_levels,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.d_model = d_model
|
||||||
|
self.nhead = nhead
|
||||||
|
self.dec_layers = num_decoder_layers
|
||||||
|
self.num_queries = num_queries # useful for single stage model only
|
||||||
|
self.num_patterns = num_patterns
|
||||||
|
if not isinstance(num_patterns, int):
|
||||||
|
Warning("num_patterns should be int but {}".format(type(num_patterns)))
|
||||||
|
self.num_patterns = 0
|
||||||
|
|
||||||
|
if num_feature_levels > 1:
|
||||||
|
if self.num_encoder_layers > 0:
|
||||||
|
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
|
||||||
|
else:
|
||||||
|
self.level_embed = None
|
||||||
|
|
||||||
|
self.learnable_tgt_init = learnable_tgt_init
|
||||||
|
assert learnable_tgt_init, "why not learnable_tgt_init"
|
||||||
|
self.embed_init_tgt = embed_init_tgt
|
||||||
|
if (two_stage_type != "no" and embed_init_tgt) or (two_stage_type == "no"):
|
||||||
|
self.tgt_embed = nn.Embedding(self.num_queries, d_model)
|
||||||
|
nn.init.normal_(self.tgt_embed.weight.data)
|
||||||
|
else:
|
||||||
|
self.tgt_embed = None
|
||||||
|
|
||||||
|
# for two stage
|
||||||
|
self.two_stage_type = two_stage_type
|
||||||
|
assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
|
||||||
|
two_stage_type
|
||||||
|
)
|
||||||
|
if two_stage_type == "standard":
|
||||||
|
# anchor selection at the output of encoder
|
||||||
|
self.enc_output = nn.Linear(d_model, d_model)
|
||||||
|
self.enc_output_norm = nn.LayerNorm(d_model)
|
||||||
|
self.two_stage_wh_embedding = None
|
||||||
|
|
||||||
|
if two_stage_type == "no":
|
||||||
|
self.init_ref_points(num_queries) # init self.refpoint_embed
|
||||||
|
|
||||||
|
self.enc_out_class_embed = None
|
||||||
|
self.enc_out_bbox_embed = None
|
||||||
|
|
||||||
|
self._reset_parameters()
|
||||||
|
|
||||||
|
def _reset_parameters(self):
|
||||||
|
for p in self.parameters():
|
||||||
|
if p.dim() > 1:
|
||||||
|
nn.init.xavier_uniform_(p)
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, MSDeformAttn):
|
||||||
|
m._reset_parameters()
|
||||||
|
if self.num_feature_levels > 1 and self.level_embed is not None:
|
||||||
|
nn.init.normal_(self.level_embed)
|
||||||
|
|
||||||
|
def get_valid_ratio(self, mask):
|
||||||
|
_, H, W = mask.shape
|
||||||
|
valid_H = torch.sum(~mask[:, :, 0], 1)
|
||||||
|
valid_W = torch.sum(~mask[:, 0, :], 1)
|
||||||
|
valid_ratio_h = valid_H.float() / H
|
||||||
|
valid_ratio_w = valid_W.float() / W
|
||||||
|
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
|
||||||
|
return valid_ratio
|
||||||
|
|
||||||
|
def init_ref_points(self, use_num_queries):
|
||||||
|
self.refpoint_embed = nn.Embedding(use_num_queries, 4)
|
||||||
|
|
||||||
|
def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None, text_dict=None):
|
||||||
|
"""
|
||||||
|
Input:
|
||||||
|
- srcs: List of multi features [bs, ci, hi, wi]
|
||||||
|
- masks: List of multi masks [bs, hi, wi]
|
||||||
|
- refpoint_embed: [bs, num_dn, 4]. None in infer
|
||||||
|
- pos_embeds: List of multi pos embeds [bs, ci, hi, wi]
|
||||||
|
- tgt: [bs, num_dn, d_model]. None in infer
|
||||||
|
|
||||||
|
"""
|
||||||
|
# prepare input for encoder
|
||||||
|
src_flatten = []
|
||||||
|
mask_flatten = []
|
||||||
|
lvl_pos_embed_flatten = []
|
||||||
|
spatial_shapes = []
|
||||||
|
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
|
||||||
|
bs, c, h, w = src.shape
|
||||||
|
spatial_shape = (h, w)
|
||||||
|
spatial_shapes.append(spatial_shape)
|
||||||
|
|
||||||
|
src = src.flatten(2).transpose(1, 2) # bs, hw, c
|
||||||
|
mask = mask.flatten(1) # bs, hw
|
||||||
|
pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
|
||||||
|
if self.num_feature_levels > 1 and self.level_embed is not None:
|
||||||
|
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
|
||||||
|
else:
|
||||||
|
lvl_pos_embed = pos_embed
|
||||||
|
lvl_pos_embed_flatten.append(lvl_pos_embed)
|
||||||
|
src_flatten.append(src)
|
||||||
|
mask_flatten.append(mask)
|
||||||
|
src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
|
||||||
|
mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
|
||||||
|
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c
|
||||||
|
spatial_shapes = torch.as_tensor(
|
||||||
|
spatial_shapes, dtype=torch.long, device=src_flatten.device
|
||||||
|
)
|
||||||
|
level_start_index = torch.cat(
|
||||||
|
(spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
|
||||||
|
)
|
||||||
|
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
|
||||||
|
|
||||||
|
# two stage
|
||||||
|
enc_topk_proposals = enc_refpoint_embed = None
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# Begin Encoder
|
||||||
|
#########################################################
|
||||||
|
memory, memory_text = self.encoder(
|
||||||
|
src_flatten,
|
||||||
|
pos=lvl_pos_embed_flatten,
|
||||||
|
level_start_index=level_start_index,
|
||||||
|
spatial_shapes=spatial_shapes,
|
||||||
|
valid_ratios=valid_ratios,
|
||||||
|
key_padding_mask=mask_flatten,
|
||||||
|
memory_text=text_dict["encoded_text"],
|
||||||
|
text_attention_mask=~text_dict["text_token_mask"],
|
||||||
|
# we ~ the mask . False means use the token; True means pad the token
|
||||||
|
position_ids=text_dict["position_ids"],
|
||||||
|
text_self_attention_masks=text_dict["text_self_attention_masks"],
|
||||||
|
)
|
||||||
|
#########################################################
|
||||||
|
# End Encoder
|
||||||
|
# - memory: bs, \sum{hw}, c
|
||||||
|
# - mask_flatten: bs, \sum{hw}
|
||||||
|
# - lvl_pos_embed_flatten: bs, \sum{hw}, c
|
||||||
|
# - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
|
||||||
|
# - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
|
||||||
|
#########################################################
|
||||||
|
text_dict["encoded_text"] = memory_text
|
||||||
|
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
|
||||||
|
# if memory.isnan().any() | memory.isinf().any():
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
|
||||||
|
if self.two_stage_type == "standard":
|
||||||
|
output_memory, output_proposals = gen_encoder_output_proposals(
|
||||||
|
memory, mask_flatten, spatial_shapes
|
||||||
|
)
|
||||||
|
output_memory = self.enc_output_norm(self.enc_output(output_memory))
|
||||||
|
|
||||||
|
if text_dict is not None:
|
||||||
|
enc_outputs_class_unselected = self.enc_out_class_embed(output_memory, text_dict)
|
||||||
|
else:
|
||||||
|
enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
|
||||||
|
|
||||||
|
topk_logits = enc_outputs_class_unselected.max(-1)[0]
|
||||||
|
enc_outputs_coord_unselected = (
|
||||||
|
self.enc_out_bbox_embed(output_memory) + output_proposals
|
||||||
|
) # (bs, \sum{hw}, 4) unsigmoid
|
||||||
|
topk = self.num_queries
|
||||||
|
|
||||||
|
topk_proposals = torch.topk(topk_logits, topk, dim=1)[1] # bs, nq
|
||||||
|
|
||||||
|
# gather boxes
|
||||||
|
refpoint_embed_undetach = torch.gather(
|
||||||
|
enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
|
||||||
|
) # unsigmoid
|
||||||
|
refpoint_embed_ = refpoint_embed_undetach.detach()
|
||||||
|
init_box_proposal = torch.gather(
|
||||||
|
output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
|
||||||
|
).sigmoid() # sigmoid
|
||||||
|
|
||||||
|
# gather tgt
|
||||||
|
tgt_undetach = torch.gather(
|
||||||
|
output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model)
|
||||||
|
)
|
||||||
|
if self.embed_init_tgt:
|
||||||
|
tgt_ = (
|
||||||
|
self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
|
||||||
|
) # nq, bs, d_model
|
||||||
|
else:
|
||||||
|
tgt_ = tgt_undetach.detach()
|
||||||
|
|
||||||
|
if refpoint_embed is not None:
|
||||||
|
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
|
||||||
|
tgt = torch.cat([tgt, tgt_], dim=1)
|
||||||
|
else:
|
||||||
|
refpoint_embed, tgt = refpoint_embed_, tgt_
|
||||||
|
|
||||||
|
elif self.two_stage_type == "no":
|
||||||
|
tgt_ = (
|
||||||
|
self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
|
||||||
|
) # nq, bs, d_model
|
||||||
|
refpoint_embed_ = (
|
||||||
|
self.refpoint_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
|
||||||
|
) # nq, bs, 4
|
||||||
|
|
||||||
|
if refpoint_embed is not None:
|
||||||
|
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
|
||||||
|
tgt = torch.cat([tgt, tgt_], dim=1)
|
||||||
|
else:
|
||||||
|
refpoint_embed, tgt = refpoint_embed_, tgt_
|
||||||
|
|
||||||
|
if self.num_patterns > 0:
|
||||||
|
tgt_embed = tgt.repeat(1, self.num_patterns, 1)
|
||||||
|
refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
|
||||||
|
tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(
|
||||||
|
self.num_queries, 1
|
||||||
|
) # 1, n_q*n_pat, d_model
|
||||||
|
tgt = tgt_embed + tgt_pat
|
||||||
|
|
||||||
|
init_box_proposal = refpoint_embed_.sigmoid()
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("unknown two_stage_type {}".format(self.two_stage_type))
|
||||||
|
#########################################################
|
||||||
|
# End preparing tgt
|
||||||
|
# - tgt: bs, NQ, d_model
|
||||||
|
# - refpoint_embed(unsigmoid): bs, NQ, d_model
|
||||||
|
#########################################################
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# Begin Decoder
|
||||||
|
#########################################################
|
||||||
|
hs, references = self.decoder(
|
||||||
|
tgt=tgt.transpose(0, 1),
|
||||||
|
memory=memory.transpose(0, 1),
|
||||||
|
memory_key_padding_mask=mask_flatten,
|
||||||
|
pos=lvl_pos_embed_flatten.transpose(0, 1),
|
||||||
|
refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
|
||||||
|
level_start_index=level_start_index,
|
||||||
|
spatial_shapes=spatial_shapes,
|
||||||
|
valid_ratios=valid_ratios,
|
||||||
|
tgt_mask=attn_mask,
|
||||||
|
memory_text=text_dict["encoded_text"],
|
||||||
|
text_attention_mask=~text_dict["text_token_mask"],
|
||||||
|
# we ~ the mask . False means use the token; True means pad the token
|
||||||
|
)
|
||||||
|
#########################################################
|
||||||
|
# End Decoder
|
||||||
|
# hs: n_dec, bs, nq, d_model
|
||||||
|
# references: n_dec+1, bs, nq, query_dim
|
||||||
|
#########################################################
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# Begin postprocess
|
||||||
|
#########################################################
|
||||||
|
if self.two_stage_type == "standard":
|
||||||
|
hs_enc = tgt_undetach.unsqueeze(0)
|
||||||
|
ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
|
||||||
|
else:
|
||||||
|
hs_enc = ref_enc = None
|
||||||
|
#########################################################
|
||||||
|
# End postprocess
|
||||||
|
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or (n_enc, bs, nq, d_model) or None
|
||||||
|
# ref_enc: (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or (n_enc, bs, nq, d_model) or None
|
||||||
|
#########################################################
|
||||||
|
|
||||||
|
return hs, references, hs_enc, ref_enc, init_box_proposal
|
||||||
|
# hs: (n_dec, bs, nq, d_model)
|
||||||
|
# references: sigmoid coordinates. (n_dec+1, bs, bq, 4)
|
||||||
|
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or None
|
||||||
|
# ref_enc: sigmoid coordinates. \
|
||||||
|
# (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or None
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
encoder_layer,
|
||||||
|
num_layers,
|
||||||
|
d_model=256,
|
||||||
|
num_queries=300,
|
||||||
|
enc_layer_share=False,
|
||||||
|
text_enhance_layer=None,
|
||||||
|
feature_fusion_layer=None,
|
||||||
|
use_checkpoint=False,
|
||||||
|
use_transformer_ckpt=False,
|
||||||
|
):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_layer (_type_): _description_
|
||||||
|
num_layers (_type_): _description_
|
||||||
|
norm (_type_, optional): _description_. Defaults to None.
|
||||||
|
d_model (int, optional): _description_. Defaults to 256.
|
||||||
|
num_queries (int, optional): _description_. Defaults to 300.
|
||||||
|
enc_layer_share (bool, optional): _description_. Defaults to False.
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
# prepare layers
|
||||||
|
self.layers = []
|
||||||
|
self.text_layers = []
|
||||||
|
self.fusion_layers = []
|
||||||
|
if num_layers > 0:
|
||||||
|
self.layers = _get_clones(encoder_layer, num_layers, layer_share=enc_layer_share)
|
||||||
|
|
||||||
|
if text_enhance_layer is not None:
|
||||||
|
self.text_layers = _get_clones(
|
||||||
|
text_enhance_layer, num_layers, layer_share=enc_layer_share
|
||||||
|
)
|
||||||
|
if feature_fusion_layer is not None:
|
||||||
|
self.fusion_layers = _get_clones(
|
||||||
|
feature_fusion_layer, num_layers, layer_share=enc_layer_share
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.layers = []
|
||||||
|
del encoder_layer
|
||||||
|
|
||||||
|
if text_enhance_layer is not None:
|
||||||
|
self.text_layers = []
|
||||||
|
del text_enhance_layer
|
||||||
|
if feature_fusion_layer is not None:
|
||||||
|
self.fusion_layers = []
|
||||||
|
del feature_fusion_layer
|
||||||
|
|
||||||
|
self.query_scale = None
|
||||||
|
self.num_queries = num_queries
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.d_model = d_model
|
||||||
|
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
self.use_transformer_ckpt = use_transformer_ckpt
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_reference_points(spatial_shapes, valid_ratios, device):
|
||||||
|
reference_points_list = []
|
||||||
|
for lvl, (H_, W_) in enumerate(spatial_shapes):
|
||||||
|
|
||||||
|
ref_y, ref_x = torch.meshgrid(
|
||||||
|
torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
|
||||||
|
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
|
||||||
|
)
|
||||||
|
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
|
||||||
|
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
|
||||||
|
ref = torch.stack((ref_x, ref_y), -1)
|
||||||
|
reference_points_list.append(ref)
|
||||||
|
reference_points = torch.cat(reference_points_list, 1)
|
||||||
|
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
|
||||||
|
return reference_points
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
# for images
|
||||||
|
src: Tensor,
|
||||||
|
pos: Tensor,
|
||||||
|
spatial_shapes: Tensor,
|
||||||
|
level_start_index: Tensor,
|
||||||
|
valid_ratios: Tensor,
|
||||||
|
key_padding_mask: Tensor,
|
||||||
|
# for texts
|
||||||
|
memory_text: Tensor = None,
|
||||||
|
text_attention_mask: Tensor = None,
|
||||||
|
pos_text: Tensor = None,
|
||||||
|
text_self_attention_masks: Tensor = None,
|
||||||
|
position_ids: Tensor = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Input:
|
||||||
|
- src: [bs, sum(hi*wi), 256]
|
||||||
|
- pos: pos embed for src. [bs, sum(hi*wi), 256]
|
||||||
|
- spatial_shapes: h,w of each level [num_level, 2]
|
||||||
|
- level_start_index: [num_level] start point of level in sum(hi*wi).
|
||||||
|
- valid_ratios: [bs, num_level, 2]
|
||||||
|
- key_padding_mask: [bs, sum(hi*wi)]
|
||||||
|
|
||||||
|
- memory_text: bs, n_text, 256
|
||||||
|
- text_attention_mask: bs, n_text
|
||||||
|
False for no padding; True for padding
|
||||||
|
- pos_text: bs, n_text, 256
|
||||||
|
|
||||||
|
- position_ids: bs, n_text
|
||||||
|
Intermedia:
|
||||||
|
- reference_points: [bs, sum(hi*wi), num_level, 2]
|
||||||
|
Outpus:
|
||||||
|
- output: [bs, sum(hi*wi), 256]
|
||||||
|
"""
|
||||||
|
|
||||||
|
output = src
|
||||||
|
|
||||||
|
# preparation and reshape
|
||||||
|
if self.num_layers > 0:
|
||||||
|
reference_points = self.get_reference_points(
|
||||||
|
spatial_shapes, valid_ratios, device=src.device
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.text_layers:
|
||||||
|
# generate pos_text
|
||||||
|
bs, n_text, text_dim = memory_text.shape
|
||||||
|
if pos_text is None and position_ids is None:
|
||||||
|
pos_text = (
|
||||||
|
torch.arange(n_text, device=memory_text.device)
|
||||||
|
.float()
|
||||||
|
.unsqueeze(0)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.repeat(bs, 1, 1)
|
||||||
|
)
|
||||||
|
pos_text = get_sine_pos_embed(pos_text, num_pos_feats=256, exchange_xy=False)
|
||||||
|
if position_ids is not None:
|
||||||
|
pos_text = get_sine_pos_embed(
|
||||||
|
position_ids[..., None], num_pos_feats=256, exchange_xy=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# main process
|
||||||
|
for layer_id, layer in enumerate(self.layers):
|
||||||
|
# if output.isnan().any() or memory_text.isnan().any():
|
||||||
|
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
if self.fusion_layers:
|
||||||
|
if self.use_checkpoint:
|
||||||
|
output, memory_text = checkpoint.checkpoint(
|
||||||
|
self.fusion_layers[layer_id],
|
||||||
|
output,
|
||||||
|
memory_text,
|
||||||
|
key_padding_mask,
|
||||||
|
text_attention_mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output, memory_text = self.fusion_layers[layer_id](
|
||||||
|
v=output,
|
||||||
|
l=memory_text,
|
||||||
|
attention_mask_v=key_padding_mask,
|
||||||
|
attention_mask_l=text_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.text_layers:
|
||||||
|
memory_text = self.text_layers[layer_id](
|
||||||
|
src=memory_text.transpose(0, 1),
|
||||||
|
src_mask=~text_self_attention_masks, # note we use ~ for mask here
|
||||||
|
src_key_padding_mask=text_attention_mask,
|
||||||
|
pos=(pos_text.transpose(0, 1) if pos_text is not None else None),
|
||||||
|
).transpose(0, 1)
|
||||||
|
|
||||||
|
# main process
|
||||||
|
if self.use_transformer_ckpt:
|
||||||
|
output = checkpoint.checkpoint(
|
||||||
|
layer,
|
||||||
|
output,
|
||||||
|
pos,
|
||||||
|
reference_points,
|
||||||
|
spatial_shapes,
|
||||||
|
level_start_index,
|
||||||
|
key_padding_mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output = layer(
|
||||||
|
src=output,
|
||||||
|
pos=pos,
|
||||||
|
reference_points=reference_points,
|
||||||
|
spatial_shapes=spatial_shapes,
|
||||||
|
level_start_index=level_start_index,
|
||||||
|
key_padding_mask=key_padding_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
return output, memory_text
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerDecoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
decoder_layer,
|
||||||
|
num_layers,
|
||||||
|
norm=None,
|
||||||
|
return_intermediate=False,
|
||||||
|
d_model=256,
|
||||||
|
query_dim=4,
|
||||||
|
num_feature_levels=1,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if num_layers > 0:
|
||||||
|
self.layers = _get_clones(decoder_layer, num_layers)
|
||||||
|
else:
|
||||||
|
self.layers = []
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.norm = norm
|
||||||
|
self.return_intermediate = return_intermediate
|
||||||
|
assert return_intermediate, "support return_intermediate only"
|
||||||
|
self.query_dim = query_dim
|
||||||
|
assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
|
||||||
|
self.num_feature_levels = num_feature_levels
|
||||||
|
|
||||||
|
self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
|
||||||
|
self.query_pos_sine_scale = None
|
||||||
|
|
||||||
|
self.query_scale = None
|
||||||
|
self.bbox_embed = None
|
||||||
|
self.class_embed = None
|
||||||
|
|
||||||
|
self.d_model = d_model
|
||||||
|
|
||||||
|
self.ref_anchor_head = None
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
tgt,
|
||||||
|
memory,
|
||||||
|
tgt_mask: Optional[Tensor] = None,
|
||||||
|
memory_mask: Optional[Tensor] = None,
|
||||||
|
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
memory_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
pos: Optional[Tensor] = None,
|
||||||
|
refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2
|
||||||
|
# for memory
|
||||||
|
level_start_index: Optional[Tensor] = None, # num_levels
|
||||||
|
spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
|
||||||
|
valid_ratios: Optional[Tensor] = None,
|
||||||
|
# for text
|
||||||
|
memory_text: Optional[Tensor] = None,
|
||||||
|
text_attention_mask: Optional[Tensor] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Input:
|
||||||
|
- tgt: nq, bs, d_model
|
||||||
|
- memory: hw, bs, d_model
|
||||||
|
- pos: hw, bs, d_model
|
||||||
|
- refpoints_unsigmoid: nq, bs, 2/4
|
||||||
|
- valid_ratios/spatial_shapes: bs, nlevel, 2
|
||||||
|
"""
|
||||||
|
output = tgt
|
||||||
|
|
||||||
|
intermediate = []
|
||||||
|
reference_points = refpoints_unsigmoid.sigmoid()
|
||||||
|
ref_points = [reference_points]
|
||||||
|
|
||||||
|
for layer_id, layer in enumerate(self.layers):
|
||||||
|
|
||||||
|
if reference_points.shape[-1] == 4:
|
||||||
|
reference_points_input = (
|
||||||
|
reference_points[:, :, None]
|
||||||
|
* torch.cat([valid_ratios, valid_ratios], -1)[None, :]
|
||||||
|
) # nq, bs, nlevel, 4
|
||||||
|
else:
|
||||||
|
assert reference_points.shape[-1] == 2
|
||||||
|
reference_points_input = reference_points[:, :, None] * valid_ratios[None, :]
|
||||||
|
query_sine_embed = gen_sineembed_for_position(
|
||||||
|
reference_points_input[:, :, 0, :]
|
||||||
|
) # nq, bs, 256*2
|
||||||
|
|
||||||
|
# conditional query
|
||||||
|
raw_query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256
|
||||||
|
pos_scale = self.query_scale(output) if self.query_scale is not None else 1
|
||||||
|
query_pos = pos_scale * raw_query_pos
|
||||||
|
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
|
||||||
|
# if query_pos.isnan().any() | query_pos.isinf().any():
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
|
||||||
|
# main process
|
||||||
|
output = layer(
|
||||||
|
tgt=output,
|
||||||
|
tgt_query_pos=query_pos,
|
||||||
|
tgt_query_sine_embed=query_sine_embed,
|
||||||
|
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||||
|
tgt_reference_points=reference_points_input,
|
||||||
|
memory_text=memory_text,
|
||||||
|
text_attention_mask=text_attention_mask,
|
||||||
|
memory=memory,
|
||||||
|
memory_key_padding_mask=memory_key_padding_mask,
|
||||||
|
memory_level_start_index=level_start_index,
|
||||||
|
memory_spatial_shapes=spatial_shapes,
|
||||||
|
memory_pos=pos,
|
||||||
|
self_attn_mask=tgt_mask,
|
||||||
|
cross_attn_mask=memory_mask,
|
||||||
|
)
|
||||||
|
if output.isnan().any() | output.isinf().any():
|
||||||
|
print(f"output layer_id {layer_id} is nan")
|
||||||
|
try:
|
||||||
|
num_nan = output.isnan().sum().item()
|
||||||
|
num_inf = output.isinf().sum().item()
|
||||||
|
print(f"num_nan {num_nan}, num_inf {num_inf}")
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
|
||||||
|
# iter update
|
||||||
|
if self.bbox_embed is not None:
|
||||||
|
# box_holder = self.bbox_embed(output)
|
||||||
|
# box_holder[..., :self.query_dim] += inverse_sigmoid(reference_points)
|
||||||
|
# new_reference_points = box_holder[..., :self.query_dim].sigmoid()
|
||||||
|
|
||||||
|
reference_before_sigmoid = inverse_sigmoid(reference_points)
|
||||||
|
delta_unsig = self.bbox_embed[layer_id](output)
|
||||||
|
outputs_unsig = delta_unsig + reference_before_sigmoid
|
||||||
|
new_reference_points = outputs_unsig.sigmoid()
|
||||||
|
|
||||||
|
reference_points = new_reference_points.detach()
|
||||||
|
# if layer_id != self.num_layers - 1:
|
||||||
|
ref_points.append(new_reference_points)
|
||||||
|
|
||||||
|
intermediate.append(self.norm(output))
|
||||||
|
|
||||||
|
return [
|
||||||
|
[itm_out.transpose(0, 1) for itm_out in intermediate],
|
||||||
|
[itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class DeformableTransformerEncoderLayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model=256,
|
||||||
|
d_ffn=1024,
|
||||||
|
dropout=0.1,
|
||||||
|
activation="relu",
|
||||||
|
n_levels=4,
|
||||||
|
n_heads=8,
|
||||||
|
n_points=4,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# self attention
|
||||||
|
self.self_attn = MSDeformAttn(
|
||||||
|
embed_dim=d_model,
|
||||||
|
num_levels=n_levels,
|
||||||
|
num_heads=n_heads,
|
||||||
|
num_points=n_points,
|
||||||
|
batch_first=True,
|
||||||
|
)
|
||||||
|
self.dropout1 = nn.Dropout(dropout)
|
||||||
|
self.norm1 = nn.LayerNorm(d_model)
|
||||||
|
|
||||||
|
# ffn
|
||||||
|
self.linear1 = nn.Linear(d_model, d_ffn)
|
||||||
|
self.activation = _get_activation_fn(activation, d_model=d_ffn)
|
||||||
|
self.dropout2 = nn.Dropout(dropout)
|
||||||
|
self.linear2 = nn.Linear(d_ffn, d_model)
|
||||||
|
self.dropout3 = nn.Dropout(dropout)
|
||||||
|
self.norm2 = nn.LayerNorm(d_model)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def with_pos_embed(tensor, pos):
|
||||||
|
return tensor if pos is None else tensor + pos
|
||||||
|
|
||||||
|
def forward_ffn(self, src):
|
||||||
|
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
|
||||||
|
src = src + self.dropout3(src2)
|
||||||
|
src = self.norm2(src)
|
||||||
|
return src
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None
|
||||||
|
):
|
||||||
|
# self attention
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
src2 = self.self_attn(
|
||||||
|
query=self.with_pos_embed(src, pos),
|
||||||
|
reference_points=reference_points,
|
||||||
|
value=src,
|
||||||
|
spatial_shapes=spatial_shapes,
|
||||||
|
level_start_index=level_start_index,
|
||||||
|
key_padding_mask=key_padding_mask,
|
||||||
|
)
|
||||||
|
src = src + self.dropout1(src2)
|
||||||
|
src = self.norm1(src)
|
||||||
|
|
||||||
|
# ffn
|
||||||
|
src = self.forward_ffn(src)
|
||||||
|
|
||||||
|
return src
|
||||||
|
|
||||||
|
|
||||||
|
class DeformableTransformerDecoderLayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model=256,
|
||||||
|
d_ffn=1024,
|
||||||
|
dropout=0.1,
|
||||||
|
activation="relu",
|
||||||
|
n_levels=4,
|
||||||
|
n_heads=8,
|
||||||
|
n_points=4,
|
||||||
|
use_text_feat_guide=False,
|
||||||
|
use_text_cross_attention=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# cross attention
|
||||||
|
self.cross_attn = MSDeformAttn(
|
||||||
|
embed_dim=d_model,
|
||||||
|
num_levels=n_levels,
|
||||||
|
num_heads=n_heads,
|
||||||
|
num_points=n_points,
|
||||||
|
batch_first=True,
|
||||||
|
)
|
||||||
|
self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
||||||
|
self.norm1 = nn.LayerNorm(d_model)
|
||||||
|
|
||||||
|
# cross attention text
|
||||||
|
if use_text_cross_attention:
|
||||||
|
self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
|
||||||
|
self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
||||||
|
self.catext_norm = nn.LayerNorm(d_model)
|
||||||
|
|
||||||
|
# self attention
|
||||||
|
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
|
||||||
|
self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
||||||
|
self.norm2 = nn.LayerNorm(d_model)
|
||||||
|
|
||||||
|
# ffn
|
||||||
|
self.linear1 = nn.Linear(d_model, d_ffn)
|
||||||
|
self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
|
||||||
|
self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
||||||
|
self.linear2 = nn.Linear(d_ffn, d_model)
|
||||||
|
self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
||||||
|
self.norm3 = nn.LayerNorm(d_model)
|
||||||
|
|
||||||
|
self.key_aware_proj = None
|
||||||
|
self.use_text_feat_guide = use_text_feat_guide
|
||||||
|
assert not use_text_feat_guide
|
||||||
|
self.use_text_cross_attention = use_text_cross_attention
|
||||||
|
|
||||||
|
def rm_self_attn_modules(self):
|
||||||
|
self.self_attn = None
|
||||||
|
self.dropout2 = None
|
||||||
|
self.norm2 = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def with_pos_embed(tensor, pos):
|
||||||
|
return tensor if pos is None else tensor + pos
|
||||||
|
|
||||||
|
def forward_ffn(self, tgt):
|
||||||
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
|
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
|
||||||
|
tgt = tgt + self.dropout4(tgt2)
|
||||||
|
tgt = self.norm3(tgt)
|
||||||
|
return tgt
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
# for tgt
|
||||||
|
tgt: Optional[Tensor], # nq, bs, d_model
|
||||||
|
tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
|
||||||
|
tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
|
||||||
|
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
|
||||||
|
memory_text: Optional[Tensor] = None, # bs, num_token, d_model
|
||||||
|
text_attention_mask: Optional[Tensor] = None, # bs, num_token
|
||||||
|
# for memory
|
||||||
|
memory: Optional[Tensor] = None, # hw, bs, d_model
|
||||||
|
memory_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
memory_level_start_index: Optional[Tensor] = None, # num_levels
|
||||||
|
memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
|
||||||
|
memory_pos: Optional[Tensor] = None, # pos for memory
|
||||||
|
# sa
|
||||||
|
self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
|
||||||
|
cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Input:
|
||||||
|
- tgt/tgt_query_pos: nq, bs, d_model
|
||||||
|
-
|
||||||
|
"""
|
||||||
|
assert cross_attn_mask is None
|
||||||
|
|
||||||
|
# self attention
|
||||||
|
if self.self_attn is not None:
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
q = k = self.with_pos_embed(tgt, tgt_query_pos)
|
||||||
|
tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
|
||||||
|
tgt = tgt + self.dropout2(tgt2)
|
||||||
|
tgt = self.norm2(tgt)
|
||||||
|
|
||||||
|
if self.use_text_cross_attention:
|
||||||
|
tgt2 = self.ca_text(
|
||||||
|
self.with_pos_embed(tgt, tgt_query_pos),
|
||||||
|
memory_text.transpose(0, 1),
|
||||||
|
memory_text.transpose(0, 1),
|
||||||
|
key_padding_mask=text_attention_mask,
|
||||||
|
)[0]
|
||||||
|
tgt = tgt + self.catext_dropout(tgt2)
|
||||||
|
tgt = self.catext_norm(tgt)
|
||||||
|
|
||||||
|
tgt2 = self.cross_attn(
|
||||||
|
query=self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
|
||||||
|
reference_points=tgt_reference_points.transpose(0, 1).contiguous(),
|
||||||
|
value=memory.transpose(0, 1),
|
||||||
|
spatial_shapes=memory_spatial_shapes,
|
||||||
|
level_start_index=memory_level_start_index,
|
||||||
|
key_padding_mask=memory_key_padding_mask,
|
||||||
|
).transpose(0, 1)
|
||||||
|
tgt = tgt + self.dropout1(tgt2)
|
||||||
|
tgt = self.norm1(tgt)
|
||||||
|
|
||||||
|
# ffn
|
||||||
|
tgt = self.forward_ffn(tgt)
|
||||||
|
|
||||||
|
return tgt
|
||||||
|
|
||||||
|
|
||||||
|
def build_transformer(args):
|
||||||
|
return Transformer(
|
||||||
|
d_model=args.hidden_dim,
|
||||||
|
dropout=args.dropout,
|
||||||
|
nhead=args.nheads,
|
||||||
|
num_queries=args.num_queries,
|
||||||
|
dim_feedforward=args.dim_feedforward,
|
||||||
|
num_encoder_layers=args.enc_layers,
|
||||||
|
num_decoder_layers=args.dec_layers,
|
||||||
|
normalize_before=args.pre_norm,
|
||||||
|
return_intermediate_dec=True,
|
||||||
|
query_dim=args.query_dim,
|
||||||
|
activation=args.transformer_activation,
|
||||||
|
num_patterns=args.num_patterns,
|
||||||
|
num_feature_levels=args.num_feature_levels,
|
||||||
|
enc_n_points=args.enc_n_points,
|
||||||
|
dec_n_points=args.dec_n_points,
|
||||||
|
learnable_tgt_init=True,
|
||||||
|
# two stage
|
||||||
|
two_stage_type=args.two_stage_type, # ['no', 'standard', 'early']
|
||||||
|
embed_init_tgt=args.embed_init_tgt,
|
||||||
|
use_text_enhancer=args.use_text_enhancer,
|
||||||
|
use_fusion_layer=args.use_fusion_layer,
|
||||||
|
use_checkpoint=args.use_checkpoint,
|
||||||
|
use_transformer_ckpt=args.use_transformer_ckpt,
|
||||||
|
use_text_cross_attention=args.use_text_cross_attention,
|
||||||
|
text_dropout=args.text_dropout,
|
||||||
|
fusion_dropout=args.fusion_dropout,
|
||||||
|
fusion_droppath=args.fusion_droppath,
|
||||||
|
)
|
@ -0,0 +1,123 @@
|
|||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Grounding DINO
|
||||||
|
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||||
|
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
"""
|
||||||
|
DETR Transformer class.
|
||||||
|
|
||||||
|
Copy-paste from torch.nn.Transformer with modifications:
|
||||||
|
* positional encodings are passed in MHattention
|
||||||
|
* extra LN at the end of encoder is removed
|
||||||
|
* decoder returns a stack of activations from all decoding layers
|
||||||
|
"""
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
|
MLP,
|
||||||
|
_get_activation_fn,
|
||||||
|
_get_clones,
|
||||||
|
gen_encoder_output_proposals,
|
||||||
|
gen_sineembed_for_position,
|
||||||
|
sigmoid_focal_loss,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TextTransformer(nn.Module):
|
||||||
|
def __init__(self, num_layers, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.d_model = d_model
|
||||||
|
self.nheads = nheads
|
||||||
|
self.dim_feedforward = dim_feedforward
|
||||||
|
self.norm = None
|
||||||
|
|
||||||
|
single_encoder_layer = TransformerEncoderLayer(
|
||||||
|
d_model=d_model, nhead=nheads, dim_feedforward=dim_feedforward, dropout=dropout
|
||||||
|
)
|
||||||
|
self.layers = _get_clones(single_encoder_layer, num_layers)
|
||||||
|
|
||||||
|
def forward(self, memory_text: torch.Tensor, text_attention_mask: torch.Tensor):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text_attention_mask: bs, num_token
|
||||||
|
memory_text: bs, num_token, d_model
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
output: bs, num_token, d_model
|
||||||
|
"""
|
||||||
|
|
||||||
|
output = memory_text.transpose(0, 1)
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
output = layer(output, src_key_padding_mask=text_attention_mask)
|
||||||
|
|
||||||
|
if self.norm is not None:
|
||||||
|
output = self.norm(output)
|
||||||
|
|
||||||
|
return output.transpose(0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoderLayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model,
|
||||||
|
nhead,
|
||||||
|
dim_feedforward=2048,
|
||||||
|
dropout=0.1,
|
||||||
|
activation="relu",
|
||||||
|
normalize_before=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||||
|
# Implementation of Feedforward model
|
||||||
|
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||||
|
|
||||||
|
self.norm1 = nn.LayerNorm(d_model)
|
||||||
|
self.norm2 = nn.LayerNorm(d_model)
|
||||||
|
self.dropout1 = nn.Dropout(dropout)
|
||||||
|
self.dropout2 = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
self.activation = _get_activation_fn(activation)
|
||||||
|
self.normalize_before = normalize_before
|
||||||
|
self.nhead = nhead
|
||||||
|
|
||||||
|
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||||
|
return tensor if pos is None else tensor + pos
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
src,
|
||||||
|
src_mask: Optional[Tensor] = None,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
pos: Optional[Tensor] = None,
|
||||||
|
):
|
||||||
|
# repeat attn mask
|
||||||
|
if src_mask.dim() == 3 and src_mask.shape[0] == src.shape[1]:
|
||||||
|
# bs, num_q, num_k
|
||||||
|
src_mask = src_mask.repeat(self.nhead, 1, 1)
|
||||||
|
|
||||||
|
q = k = self.with_pos_embed(src, pos)
|
||||||
|
|
||||||
|
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)[0]
|
||||||
|
|
||||||
|
# src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
|
||||||
|
src = src + self.dropout1(src2)
|
||||||
|
src = self.norm1(src)
|
||||||
|
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
||||||
|
src = src + self.dropout2(src2)
|
||||||
|
src = self.norm2(src)
|
||||||
|
return src
|
@ -0,0 +1,268 @@
|
|||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Grounding DINO
|
||||||
|
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||||
|
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
|
||||||
|
def _get_clones(module, N, layer_share=False):
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
if layer_share:
|
||||||
|
return nn.ModuleList([module for i in range(N)])
|
||||||
|
else:
|
||||||
|
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
||||||
|
|
||||||
|
|
||||||
|
def get_sine_pos_embed(
|
||||||
|
pos_tensor: torch.Tensor,
|
||||||
|
num_pos_feats: int = 128,
|
||||||
|
temperature: int = 10000,
|
||||||
|
exchange_xy: bool = True,
|
||||||
|
):
|
||||||
|
"""generate sine position embedding from a position tensor
|
||||||
|
Args:
|
||||||
|
pos_tensor (torch.Tensor): shape: [..., n].
|
||||||
|
num_pos_feats (int): projected shape for each float in the tensor.
|
||||||
|
temperature (int): temperature in the sine/cosine function.
|
||||||
|
exchange_xy (bool, optional): exchange pos x and pos y. \
|
||||||
|
For example, input tensor is [x,y], the results will be [pos(y), pos(x)]. Defaults to True.
|
||||||
|
Returns:
|
||||||
|
pos_embed (torch.Tensor): shape: [..., n*num_pos_feats].
|
||||||
|
"""
|
||||||
|
scale = 2 * math.pi
|
||||||
|
dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
|
||||||
|
dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
|
||||||
|
|
||||||
|
def sine_func(x: torch.Tensor):
|
||||||
|
sin_x = x * scale / dim_t
|
||||||
|
sin_x = torch.stack((sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=3).flatten(2)
|
||||||
|
return sin_x
|
||||||
|
|
||||||
|
pos_res = [sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)]
|
||||||
|
if exchange_xy:
|
||||||
|
pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
|
||||||
|
pos_res = torch.cat(pos_res, dim=-1)
|
||||||
|
return pos_res
|
||||||
|
|
||||||
|
|
||||||
|
def gen_encoder_output_proposals(
|
||||||
|
memory: Tensor, memory_padding_mask: Tensor, spatial_shapes: Tensor, learnedwh=None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Input:
|
||||||
|
- memory: bs, \sum{hw}, d_model
|
||||||
|
- memory_padding_mask: bs, \sum{hw}
|
||||||
|
- spatial_shapes: nlevel, 2
|
||||||
|
- learnedwh: 2
|
||||||
|
Output:
|
||||||
|
- output_memory: bs, \sum{hw}, d_model
|
||||||
|
- output_proposals: bs, \sum{hw}, 4
|
||||||
|
"""
|
||||||
|
N_, S_, C_ = memory.shape
|
||||||
|
proposals = []
|
||||||
|
_cur = 0
|
||||||
|
for lvl, (H_, W_) in enumerate(spatial_shapes):
|
||||||
|
mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(N_, H_, W_, 1)
|
||||||
|
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
|
||||||
|
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
|
||||||
|
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
|
||||||
|
grid_y, grid_x = torch.meshgrid(
|
||||||
|
torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
|
||||||
|
torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
|
||||||
|
)
|
||||||
|
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
|
||||||
|
|
||||||
|
scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
|
||||||
|
grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
|
||||||
|
|
||||||
|
if learnedwh is not None:
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0**lvl)
|
||||||
|
else:
|
||||||
|
wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
|
||||||
|
|
||||||
|
# scale = torch.cat([W_[None].unsqueeze(-1), H_[None].unsqueeze(-1)], 1).view(1, 1, 1, 2).repeat(N_, 1, 1, 1)
|
||||||
|
# grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
|
||||||
|
# wh = torch.ones_like(grid) / scale
|
||||||
|
proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
|
||||||
|
proposals.append(proposal)
|
||||||
|
_cur += H_ * W_
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
output_proposals = torch.cat(proposals, 1)
|
||||||
|
output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(
|
||||||
|
-1, keepdim=True
|
||||||
|
)
|
||||||
|
output_proposals = torch.log(output_proposals / (1 - output_proposals)) # unsigmoid
|
||||||
|
output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf"))
|
||||||
|
output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))
|
||||||
|
|
||||||
|
output_memory = memory
|
||||||
|
output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
|
||||||
|
output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
|
||||||
|
|
||||||
|
# output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
|
||||||
|
# output_memory = output_memory.masked_fill(~output_proposals_valid, float('inf'))
|
||||||
|
|
||||||
|
return output_memory, output_proposals
|
||||||
|
|
||||||
|
|
||||||
|
class RandomBoxPerturber:
|
||||||
|
def __init__(
|
||||||
|
self, x_noise_scale=0.2, y_noise_scale=0.2, w_noise_scale=0.2, h_noise_scale=0.2
|
||||||
|
) -> None:
|
||||||
|
self.noise_scale = torch.Tensor(
|
||||||
|
[x_noise_scale, y_noise_scale, w_noise_scale, h_noise_scale]
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, refanchors: Tensor) -> Tensor:
|
||||||
|
nq, bs, query_dim = refanchors.shape
|
||||||
|
device = refanchors.device
|
||||||
|
|
||||||
|
noise_raw = torch.rand_like(refanchors)
|
||||||
|
noise_scale = self.noise_scale.to(device)[:query_dim]
|
||||||
|
|
||||||
|
new_refanchors = refanchors * (1 + (noise_raw - 0.5) * noise_scale)
|
||||||
|
return new_refanchors.clamp_(0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def sigmoid_focal_loss(
|
||||||
|
inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, no_reduction=False
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
|
||||||
|
Args:
|
||||||
|
inputs: A float tensor of arbitrary shape.
|
||||||
|
The predictions for each example.
|
||||||
|
targets: A float tensor with the same shape as inputs. Stores the binary
|
||||||
|
classification label for each element in inputs
|
||||||
|
(0 for the negative class and 1 for the positive class).
|
||||||
|
alpha: (optional) Weighting factor in range (0,1) to balance
|
||||||
|
positive vs negative examples. Default = -1 (no weighting).
|
||||||
|
gamma: Exponent of the modulating factor (1 - p_t) to
|
||||||
|
balance easy vs hard examples.
|
||||||
|
Returns:
|
||||||
|
Loss tensor
|
||||||
|
"""
|
||||||
|
prob = inputs.sigmoid()
|
||||||
|
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
|
||||||
|
p_t = prob * targets + (1 - prob) * (1 - targets)
|
||||||
|
loss = ce_loss * ((1 - p_t) ** gamma)
|
||||||
|
|
||||||
|
if alpha >= 0:
|
||||||
|
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
|
||||||
|
loss = alpha_t * loss
|
||||||
|
|
||||||
|
if no_reduction:
|
||||||
|
return loss
|
||||||
|
|
||||||
|
return loss.mean(1).sum() / num_boxes
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
"""Very simple multi-layer perceptron (also called FFN)"""
|
||||||
|
|
||||||
|
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
||||||
|
super().__init__()
|
||||||
|
self.num_layers = num_layers
|
||||||
|
h = [hidden_dim] * (num_layers - 1)
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _get_activation_fn(activation, d_model=256, batch_dim=0):
|
||||||
|
"""Return an activation function given a string"""
|
||||||
|
if activation == "relu":
|
||||||
|
return F.relu
|
||||||
|
if activation == "gelu":
|
||||||
|
return F.gelu
|
||||||
|
if activation == "glu":
|
||||||
|
return F.glu
|
||||||
|
if activation == "prelu":
|
||||||
|
return nn.PReLU()
|
||||||
|
if activation == "selu":
|
||||||
|
return F.selu
|
||||||
|
|
||||||
|
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
||||||
|
|
||||||
|
|
||||||
|
def gen_sineembed_for_position(pos_tensor):
|
||||||
|
# n_query, bs, _ = pos_tensor.size()
|
||||||
|
# sineembed_tensor = torch.zeros(n_query, bs, 256)
|
||||||
|
scale = 2 * math.pi
|
||||||
|
dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
|
||||||
|
dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode='floor')) / 128)
|
||||||
|
x_embed = pos_tensor[:, :, 0] * scale
|
||||||
|
y_embed = pos_tensor[:, :, 1] * scale
|
||||||
|
pos_x = x_embed[:, :, None] / dim_t
|
||||||
|
pos_y = y_embed[:, :, None] / dim_t
|
||||||
|
pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
|
||||||
|
pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
|
||||||
|
if pos_tensor.size(-1) == 2:
|
||||||
|
pos = torch.cat((pos_y, pos_x), dim=2)
|
||||||
|
elif pos_tensor.size(-1) == 4:
|
||||||
|
w_embed = pos_tensor[:, :, 2] * scale
|
||||||
|
pos_w = w_embed[:, :, None] / dim_t
|
||||||
|
pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
|
||||||
|
|
||||||
|
h_embed = pos_tensor[:, :, 3] * scale
|
||||||
|
pos_h = h_embed[:, :, None] / dim_t
|
||||||
|
pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
|
||||||
|
|
||||||
|
pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
|
||||||
|
return pos
|
||||||
|
|
||||||
|
|
||||||
|
class ContrastiveEmbed(nn.Module):
|
||||||
|
def __init__(self, max_text_len=256):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
max_text_len: max length of text.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.max_text_len = max_text_len
|
||||||
|
|
||||||
|
def forward(self, x, text_dict):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (_type_): _description_
|
||||||
|
text_dict (_type_): _description_
|
||||||
|
{
|
||||||
|
'encoded_text': encoded_text, # bs, 195, d_model
|
||||||
|
'text_token_mask': text_token_mask, # bs, 195
|
||||||
|
# True for used tokens. False for padding tokens
|
||||||
|
}
|
||||||
|
Returns:
|
||||||
|
_type_: _description_
|
||||||
|
"""
|
||||||
|
assert isinstance(text_dict, dict)
|
||||||
|
|
||||||
|
y = text_dict["encoded_text"]
|
||||||
|
text_token_mask = text_dict["text_token_mask"]
|
||||||
|
|
||||||
|
res = x @ y.transpose(-1, -2)
|
||||||
|
res.masked_fill_(~text_token_mask[:, None, :], float("-inf"))
|
||||||
|
|
||||||
|
# padding to max_text_len
|
||||||
|
new_res = torch.full((*res.shape[:-1], self.max_text_len), float("-inf"), device=res.device)
|
||||||
|
new_res[..., : res.shape[-1]] = res
|
||||||
|
|
||||||
|
return new_res
|
@ -0,0 +1,18 @@
|
|||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Grounding DINO
|
||||||
|
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||||
|
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
from .GroundingDINO import build_groundingdino
|
||||||
|
|
||||||
|
|
||||||
|
def build_model(args):
|
||||||
|
# we use register to maintain models from catdet6 on.
|
||||||
|
from .registry import MODULE_BUILD_FUNCS
|
||||||
|
|
||||||
|
assert args.modelname in MODULE_BUILD_FUNCS._module_dict
|
||||||
|
build_func = MODULE_BUILD_FUNCS.get(args.modelname)
|
||||||
|
model = build_func(args)
|
||||||
|
return model
|
@ -0,0 +1,66 @@
|
|||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Grounding DINO
|
||||||
|
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||||
|
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# @Author: Yihao Chen
|
||||||
|
# @Date: 2021-08-16 16:03:17
|
||||||
|
# @Last Modified by: Shilong Liu
|
||||||
|
# @Last Modified time: 2022-01-23 15:26
|
||||||
|
# modified from mmcv
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
|
||||||
|
class Registry(object):
|
||||||
|
def __init__(self, name):
|
||||||
|
self._name = name
|
||||||
|
self._module_dict = dict()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
format_str = self.__class__.__name__ + "(name={}, items={})".format(
|
||||||
|
self._name, list(self._module_dict.keys())
|
||||||
|
)
|
||||||
|
return format_str
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._module_dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def module_dict(self):
|
||||||
|
return self._module_dict
|
||||||
|
|
||||||
|
def get(self, key):
|
||||||
|
return self._module_dict.get(key, None)
|
||||||
|
|
||||||
|
def registe_with_name(self, module_name=None, force=False):
|
||||||
|
return partial(self.register, module_name=module_name, force=force)
|
||||||
|
|
||||||
|
def register(self, module_build_function, module_name=None, force=False):
|
||||||
|
"""Register a module build function.
|
||||||
|
Args:
|
||||||
|
module (:obj:`nn.Module`): Module to be registered.
|
||||||
|
"""
|
||||||
|
if not inspect.isfunction(module_build_function):
|
||||||
|
raise TypeError(
|
||||||
|
"module_build_function must be a function, but got {}".format(
|
||||||
|
type(module_build_function)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if module_name is None:
|
||||||
|
module_name = module_build_function.__name__
|
||||||
|
if not force and module_name in self._module_dict:
|
||||||
|
raise KeyError("{} is already registered in {}".format(module_name, self.name))
|
||||||
|
self._module_dict[module_name] = module_build_function
|
||||||
|
|
||||||
|
return module_build_function
|
||||||
|
|
||||||
|
|
||||||
|
MODULE_BUILD_FUNCS = Registry("model build functions")
|
@ -0,0 +1 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
@ -0,0 +1,140 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
"""
|
||||||
|
Utilities for bounding box manipulation and GIoU.
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
from torchvision.ops.boxes import box_area
|
||||||
|
|
||||||
|
|
||||||
|
def box_cxcywh_to_xyxy(x):
|
||||||
|
x_c, y_c, w, h = x.unbind(-1)
|
||||||
|
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
|
||||||
|
return torch.stack(b, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def box_xyxy_to_cxcywh(x):
|
||||||
|
x0, y0, x1, y1 = x.unbind(-1)
|
||||||
|
b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
|
||||||
|
return torch.stack(b, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
# modified from torchvision to also return the union
|
||||||
|
def box_iou(boxes1, boxes2):
|
||||||
|
area1 = box_area(boxes1)
|
||||||
|
area2 = box_area(boxes2)
|
||||||
|
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
||||||
|
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
||||||
|
|
||||||
|
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
||||||
|
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
|
||||||
|
|
||||||
|
union = area1[:, None] + area2 - inter
|
||||||
|
|
||||||
|
iou = inter / (union + 1e-6)
|
||||||
|
return iou, union
|
||||||
|
|
||||||
|
|
||||||
|
def generalized_box_iou(boxes1, boxes2):
|
||||||
|
"""
|
||||||
|
Generalized IoU from https://giou.stanford.edu/
|
||||||
|
|
||||||
|
The boxes should be in [x0, y0, x1, y1] format
|
||||||
|
|
||||||
|
Returns a [N, M] pairwise matrix, where N = len(boxes1)
|
||||||
|
and M = len(boxes2)
|
||||||
|
"""
|
||||||
|
# degenerate boxes gives inf / nan results
|
||||||
|
# so do an early check
|
||||||
|
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
|
||||||
|
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
|
||||||
|
# except:
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
iou, union = box_iou(boxes1, boxes2)
|
||||||
|
|
||||||
|
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
||||||
|
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
||||||
|
|
||||||
|
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
||||||
|
area = wh[:, :, 0] * wh[:, :, 1]
|
||||||
|
|
||||||
|
return iou - (area - union) / (area + 1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
# modified from torchvision to also return the union
|
||||||
|
def box_iou_pairwise(boxes1, boxes2):
|
||||||
|
area1 = box_area(boxes1)
|
||||||
|
area2 = box_area(boxes2)
|
||||||
|
|
||||||
|
lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2]
|
||||||
|
rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2]
|
||||||
|
|
||||||
|
wh = (rb - lt).clamp(min=0) # [N,2]
|
||||||
|
inter = wh[:, 0] * wh[:, 1] # [N]
|
||||||
|
|
||||||
|
union = area1 + area2 - inter
|
||||||
|
|
||||||
|
iou = inter / union
|
||||||
|
return iou, union
|
||||||
|
|
||||||
|
|
||||||
|
def generalized_box_iou_pairwise(boxes1, boxes2):
|
||||||
|
"""
|
||||||
|
Generalized IoU from https://giou.stanford.edu/
|
||||||
|
|
||||||
|
Input:
|
||||||
|
- boxes1, boxes2: N,4
|
||||||
|
Output:
|
||||||
|
- giou: N, 4
|
||||||
|
"""
|
||||||
|
# degenerate boxes gives inf / nan results
|
||||||
|
# so do an early check
|
||||||
|
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
|
||||||
|
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
|
||||||
|
assert boxes1.shape == boxes2.shape
|
||||||
|
iou, union = box_iou_pairwise(boxes1, boxes2) # N, 4
|
||||||
|
|
||||||
|
lt = torch.min(boxes1[:, :2], boxes2[:, :2])
|
||||||
|
rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])
|
||||||
|
|
||||||
|
wh = (rb - lt).clamp(min=0) # [N,2]
|
||||||
|
area = wh[:, 0] * wh[:, 1]
|
||||||
|
|
||||||
|
return iou - (area - union) / area
|
||||||
|
|
||||||
|
|
||||||
|
def masks_to_boxes(masks):
|
||||||
|
"""Compute the bounding boxes around the provided masks
|
||||||
|
|
||||||
|
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
|
||||||
|
|
||||||
|
Returns a [N, 4] tensors, with the boxes in xyxy format
|
||||||
|
"""
|
||||||
|
if masks.numel() == 0:
|
||||||
|
return torch.zeros((0, 4), device=masks.device)
|
||||||
|
|
||||||
|
h, w = masks.shape[-2:]
|
||||||
|
|
||||||
|
y = torch.arange(0, h, dtype=torch.float)
|
||||||
|
x = torch.arange(0, w, dtype=torch.float)
|
||||||
|
y, x = torch.meshgrid(y, x)
|
||||||
|
|
||||||
|
x_mask = masks * x.unsqueeze(0)
|
||||||
|
x_max = x_mask.flatten(1).max(-1)[0]
|
||||||
|
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
|
||||||
|
|
||||||
|
y_mask = masks * y.unsqueeze(0)
|
||||||
|
y_max = y_mask.flatten(1).max(-1)[0]
|
||||||
|
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
|
||||||
|
|
||||||
|
return torch.stack([x_min, y_min, x_max, y_max], 1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
x = torch.rand(5, 4)
|
||||||
|
y = torch.rand(3, 4)
|
||||||
|
iou, union = box_iou(x, y)
|
||||||
|
import ipdb
|
||||||
|
|
||||||
|
ipdb.set_trace()
|
@ -0,0 +1,29 @@
|
|||||||
|
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
|
||||||
|
import os
|
||||||
|
|
||||||
|
def get_tokenlizer(text_encoder_type):
|
||||||
|
if not isinstance(text_encoder_type, str):
|
||||||
|
# print("text_encoder_type is not a str")
|
||||||
|
if hasattr(text_encoder_type, "text_encoder_type"):
|
||||||
|
text_encoder_type = text_encoder_type.text_encoder_type
|
||||||
|
elif text_encoder_type.get("text_encoder_type", False):
|
||||||
|
text_encoder_type = text_encoder_type.get("text_encoder_type")
|
||||||
|
elif os.path.isdir(text_encoder_type) and os.path.exists(text_encoder_type):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Unknown type of text_encoder_type: {}".format(type(text_encoder_type))
|
||||||
|
)
|
||||||
|
print("final text_encoder_type: {}".format(text_encoder_type))
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(text_encoder_type)
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def get_pretrained_language_model(text_encoder_type):
|
||||||
|
if text_encoder_type == "bert-base-uncased" or (os.path.isdir(text_encoder_type) and os.path.exists(text_encoder_type)):
|
||||||
|
return BertModel.from_pretrained(text_encoder_type)
|
||||||
|
if text_encoder_type == "roberta-base":
|
||||||
|
return RobertaModel.from_pretrained(text_encoder_type)
|
||||||
|
|
||||||
|
raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type))
|
@ -0,0 +1,259 @@
|
|||||||
|
from typing import Tuple, List
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import supervision as sv
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision.ops import box_convert
|
||||||
|
import bisect
|
||||||
|
|
||||||
|
import groundingdino.datasets.transforms as T
|
||||||
|
from groundingdino.models import build_model
|
||||||
|
from groundingdino.util.misc import clean_state_dict
|
||||||
|
from groundingdino.util.slconfig import SLConfig
|
||||||
|
from groundingdino.util.utils import get_phrases_from_posmap
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------------------------
|
||||||
|
# OLD API
|
||||||
|
# ----------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_caption(caption: str) -> str:
|
||||||
|
result = caption.lower().strip()
|
||||||
|
if result.endswith("."):
|
||||||
|
return result
|
||||||
|
return result + "."
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
|
||||||
|
args = SLConfig.fromfile(model_config_path)
|
||||||
|
args.device = device
|
||||||
|
model = build_model(args)
|
||||||
|
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
|
||||||
|
model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
|
||||||
|
transform = T.Compose(
|
||||||
|
[
|
||||||
|
T.RandomResize([800], max_size=1333),
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
image_source = Image.open(image_path).convert("RGB")
|
||||||
|
image = np.asarray(image_source)
|
||||||
|
image_transformed, _ = transform(image_source, None)
|
||||||
|
return image, image_transformed
|
||||||
|
|
||||||
|
|
||||||
|
def predict(
|
||||||
|
model,
|
||||||
|
image: torch.Tensor,
|
||||||
|
caption: str,
|
||||||
|
box_threshold: float,
|
||||||
|
text_threshold: float,
|
||||||
|
device: str = "cuda",
|
||||||
|
remove_combined: bool = False
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
|
||||||
|
caption = preprocess_caption(caption=caption)
|
||||||
|
|
||||||
|
model = model.to(device)
|
||||||
|
image = image.to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(image[None], captions=[caption])
|
||||||
|
|
||||||
|
prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0] # prediction_logits.shape = (nq, 256)
|
||||||
|
prediction_boxes = outputs["pred_boxes"].cpu()[0] # prediction_boxes.shape = (nq, 4)
|
||||||
|
|
||||||
|
mask = prediction_logits.max(dim=1)[0] > box_threshold
|
||||||
|
logits = prediction_logits[mask] # logits.shape = (n, 256)
|
||||||
|
boxes = prediction_boxes[mask] # boxes.shape = (n, 4)
|
||||||
|
|
||||||
|
tokenizer = model.tokenizer
|
||||||
|
tokenized = tokenizer(caption)
|
||||||
|
|
||||||
|
if remove_combined:
|
||||||
|
sep_idx = [i for i in range(len(tokenized['input_ids'])) if tokenized['input_ids'][i] in [101, 102, 1012]]
|
||||||
|
|
||||||
|
phrases = []
|
||||||
|
for logit in logits:
|
||||||
|
max_idx = logit.argmax()
|
||||||
|
insert_idx = bisect.bisect_left(sep_idx, max_idx)
|
||||||
|
right_idx = sep_idx[insert_idx]
|
||||||
|
left_idx = sep_idx[insert_idx - 1]
|
||||||
|
phrases.append(get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer, left_idx, right_idx).replace('.', ''))
|
||||||
|
else:
|
||||||
|
phrases = [
|
||||||
|
get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
|
||||||
|
for logit
|
||||||
|
in logits
|
||||||
|
]
|
||||||
|
|
||||||
|
return boxes, logits.max(dim=1)[0], phrases
|
||||||
|
|
||||||
|
|
||||||
|
def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str]) -> np.ndarray:
|
||||||
|
h, w, _ = image_source.shape
|
||||||
|
boxes = boxes * torch.Tensor([w, h, w, h])
|
||||||
|
xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
|
||||||
|
detections = sv.Detections(xyxy=xyxy)
|
||||||
|
|
||||||
|
labels = [
|
||||||
|
f"{phrase} {logit:.2f}"
|
||||||
|
for phrase, logit
|
||||||
|
in zip(phrases, logits)
|
||||||
|
]
|
||||||
|
|
||||||
|
box_annotator = sv.BoxAnnotator()
|
||||||
|
annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
|
||||||
|
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
||||||
|
return annotated_frame
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------------------------
|
||||||
|
# NEW API
|
||||||
|
# ----------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class Model:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_config_path: str,
|
||||||
|
model_checkpoint_path: str,
|
||||||
|
device: str = "cuda"
|
||||||
|
):
|
||||||
|
self.model = load_model(
|
||||||
|
model_config_path=model_config_path,
|
||||||
|
model_checkpoint_path=model_checkpoint_path,
|
||||||
|
device=device
|
||||||
|
).to(device)
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
def predict_with_caption(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
caption: str,
|
||||||
|
box_threshold: float = 0.35,
|
||||||
|
text_threshold: float = 0.25
|
||||||
|
) -> Tuple[sv.Detections, List[str]]:
|
||||||
|
"""
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
image = cv2.imread(IMAGE_PATH)
|
||||||
|
|
||||||
|
model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
|
||||||
|
detections, labels = model.predict_with_caption(
|
||||||
|
image=image,
|
||||||
|
caption=caption,
|
||||||
|
box_threshold=BOX_THRESHOLD,
|
||||||
|
text_threshold=TEXT_THRESHOLD
|
||||||
|
)
|
||||||
|
|
||||||
|
import supervision as sv
|
||||||
|
|
||||||
|
box_annotator = sv.BoxAnnotator()
|
||||||
|
annotated_image = box_annotator.annotate(scene=image, detections=detections, labels=labels)
|
||||||
|
"""
|
||||||
|
processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
|
||||||
|
boxes, logits, phrases = predict(
|
||||||
|
model=self.model,
|
||||||
|
image=processed_image,
|
||||||
|
caption=caption,
|
||||||
|
box_threshold=box_threshold,
|
||||||
|
text_threshold=text_threshold,
|
||||||
|
device=self.device)
|
||||||
|
source_h, source_w, _ = image.shape
|
||||||
|
detections = Model.post_process_result(
|
||||||
|
source_h=source_h,
|
||||||
|
source_w=source_w,
|
||||||
|
boxes=boxes,
|
||||||
|
logits=logits)
|
||||||
|
return detections, phrases
|
||||||
|
|
||||||
|
def predict_with_classes(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
classes: List[str],
|
||||||
|
box_threshold: float,
|
||||||
|
text_threshold: float
|
||||||
|
) -> sv.Detections:
|
||||||
|
"""
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
image = cv2.imread(IMAGE_PATH)
|
||||||
|
|
||||||
|
model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
|
||||||
|
detections = model.predict_with_classes(
|
||||||
|
image=image,
|
||||||
|
classes=CLASSES,
|
||||||
|
box_threshold=BOX_THRESHOLD,
|
||||||
|
text_threshold=TEXT_THRESHOLD
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
import supervision as sv
|
||||||
|
|
||||||
|
box_annotator = sv.BoxAnnotator()
|
||||||
|
annotated_image = box_annotator.annotate(scene=image, detections=detections)
|
||||||
|
"""
|
||||||
|
caption = ". ".join(classes)
|
||||||
|
processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
|
||||||
|
boxes, logits, phrases = predict(
|
||||||
|
model=self.model,
|
||||||
|
image=processed_image,
|
||||||
|
caption=caption,
|
||||||
|
box_threshold=box_threshold,
|
||||||
|
text_threshold=text_threshold,
|
||||||
|
device=self.device)
|
||||||
|
source_h, source_w, _ = image.shape
|
||||||
|
detections = Model.post_process_result(
|
||||||
|
source_h=source_h,
|
||||||
|
source_w=source_w,
|
||||||
|
boxes=boxes,
|
||||||
|
logits=logits)
|
||||||
|
class_id = Model.phrases2classes(phrases=phrases, classes=classes)
|
||||||
|
detections.class_id = class_id
|
||||||
|
return detections
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def preprocess_image(image_bgr: np.ndarray) -> torch.Tensor:
|
||||||
|
transform = T.Compose(
|
||||||
|
[
|
||||||
|
T.RandomResize([800], max_size=1333),
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
image_pillow = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
|
||||||
|
image_transformed, _ = transform(image_pillow, None)
|
||||||
|
return image_transformed
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def post_process_result(
|
||||||
|
source_h: int,
|
||||||
|
source_w: int,
|
||||||
|
boxes: torch.Tensor,
|
||||||
|
logits: torch.Tensor
|
||||||
|
) -> sv.Detections:
|
||||||
|
boxes = boxes * torch.Tensor([source_w, source_h, source_w, source_h])
|
||||||
|
xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
|
||||||
|
confidence = logits.numpy()
|
||||||
|
return sv.Detections(xyxy=xyxy, confidence=confidence)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def phrases2classes(phrases: List[str], classes: List[str]) -> np.ndarray:
|
||||||
|
class_ids = []
|
||||||
|
for phrase in phrases:
|
||||||
|
for class_ in classes:
|
||||||
|
if class_ in phrase:
|
||||||
|
class_ids.append(classes.index(class_))
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
class_ids.append(None)
|
||||||
|
return np.array(class_ids)
|
@ -0,0 +1,93 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
import functools
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from termcolor import colored
|
||||||
|
|
||||||
|
|
||||||
|
class _ColorfulFormatter(logging.Formatter):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self._root_name = kwargs.pop("root_name") + "."
|
||||||
|
self._abbrev_name = kwargs.pop("abbrev_name", "")
|
||||||
|
if len(self._abbrev_name):
|
||||||
|
self._abbrev_name = self._abbrev_name + "."
|
||||||
|
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def formatMessage(self, record):
|
||||||
|
record.name = record.name.replace(self._root_name, self._abbrev_name)
|
||||||
|
log = super(_ColorfulFormatter, self).formatMessage(record)
|
||||||
|
if record.levelno == logging.WARNING:
|
||||||
|
prefix = colored("WARNING", "red", attrs=["blink"])
|
||||||
|
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
|
||||||
|
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
|
||||||
|
else:
|
||||||
|
return log
|
||||||
|
return prefix + " " + log
|
||||||
|
|
||||||
|
|
||||||
|
# so that calling setup_logger multiple times won't add many handlers
|
||||||
|
@functools.lru_cache()
|
||||||
|
def setup_logger(output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None):
|
||||||
|
"""
|
||||||
|
Initialize the detectron2 logger and set its verbosity level to "INFO".
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output (str): a file name or a directory to save log. If None, will not save log file.
|
||||||
|
If ends with ".txt" or ".log", assumed to be a file name.
|
||||||
|
Otherwise, logs will be saved to `output/log.txt`.
|
||||||
|
name (str): the root module name of this logger
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
logging.Logger: a logger
|
||||||
|
"""
|
||||||
|
logger = logging.getLogger(name)
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
logger.propagate = False
|
||||||
|
|
||||||
|
if abbrev_name is None:
|
||||||
|
abbrev_name = name
|
||||||
|
|
||||||
|
plain_formatter = logging.Formatter(
|
||||||
|
"[%(asctime)s.%(msecs)03d]: %(message)s", datefmt="%m/%d %H:%M:%S"
|
||||||
|
)
|
||||||
|
# stdout logging: master only
|
||||||
|
if distributed_rank == 0:
|
||||||
|
ch = logging.StreamHandler(stream=sys.stdout)
|
||||||
|
ch.setLevel(logging.DEBUG)
|
||||||
|
if color:
|
||||||
|
formatter = _ColorfulFormatter(
|
||||||
|
colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s",
|
||||||
|
datefmt="%m/%d %H:%M:%S",
|
||||||
|
root_name=name,
|
||||||
|
abbrev_name=str(abbrev_name),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
formatter = plain_formatter
|
||||||
|
ch.setFormatter(formatter)
|
||||||
|
logger.addHandler(ch)
|
||||||
|
|
||||||
|
# file logging: all workers
|
||||||
|
if output is not None:
|
||||||
|
if output.endswith(".txt") or output.endswith(".log"):
|
||||||
|
filename = output
|
||||||
|
else:
|
||||||
|
filename = os.path.join(output, "log.txt")
|
||||||
|
if distributed_rank > 0:
|
||||||
|
filename = filename + f".rank{distributed_rank}"
|
||||||
|
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||||
|
|
||||||
|
fh = logging.StreamHandler(_cached_log_stream(filename))
|
||||||
|
fh.setLevel(logging.DEBUG)
|
||||||
|
fh.setFormatter(plain_formatter)
|
||||||
|
logger.addHandler(fh)
|
||||||
|
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
# cache the opened file object, so that different calls to `setup_logger`
|
||||||
|
# with the same file name can safely write to the same file.
|
||||||
|
@functools.lru_cache(maxsize=None)
|
||||||
|
def _cached_log_stream(filename):
|
||||||
|
return open(filename, "a")
|
@ -0,0 +1,717 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
"""
|
||||||
|
Misc functions, including distributed helpers.
|
||||||
|
|
||||||
|
Mostly copy-paste from torchvision references.
|
||||||
|
"""
|
||||||
|
import colorsys
|
||||||
|
import datetime
|
||||||
|
import functools
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from collections import OrderedDict, defaultdict, deque
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
# needed due to empty tensor bug in pytorch and torchvision 0.5
|
||||||
|
import torchvision
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
__torchvision_need_compat_flag = float(torchvision.__version__.split(".")[1]) < 7
|
||||||
|
if __torchvision_need_compat_flag:
|
||||||
|
from torchvision.ops import _new_empty_tensor
|
||||||
|
from torchvision.ops.misc import _output_size
|
||||||
|
|
||||||
|
|
||||||
|
class SmoothedValue(object):
|
||||||
|
"""Track a series of values and provide access to smoothed values over a
|
||||||
|
window or the global series average.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, window_size=20, fmt=None):
|
||||||
|
if fmt is None:
|
||||||
|
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||||
|
self.deque = deque(maxlen=window_size)
|
||||||
|
self.total = 0.0
|
||||||
|
self.count = 0
|
||||||
|
self.fmt = fmt
|
||||||
|
|
||||||
|
def update(self, value, n=1):
|
||||||
|
self.deque.append(value)
|
||||||
|
self.count += n
|
||||||
|
self.total += value * n
|
||||||
|
|
||||||
|
def synchronize_between_processes(self):
|
||||||
|
"""
|
||||||
|
Warning: does not synchronize the deque!
|
||||||
|
"""
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return
|
||||||
|
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
|
||||||
|
dist.barrier()
|
||||||
|
dist.all_reduce(t)
|
||||||
|
t = t.tolist()
|
||||||
|
self.count = int(t[0])
|
||||||
|
self.total = t[1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def median(self):
|
||||||
|
d = torch.tensor(list(self.deque))
|
||||||
|
if d.shape[0] == 0:
|
||||||
|
return 0
|
||||||
|
return d.median().item()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def avg(self):
|
||||||
|
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||||
|
return d.mean().item()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def global_avg(self):
|
||||||
|
if os.environ.get("SHILONG_AMP", None) == "1":
|
||||||
|
eps = 1e-4
|
||||||
|
else:
|
||||||
|
eps = 1e-6
|
||||||
|
return self.total / (self.count + eps)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max(self):
|
||||||
|
return max(self.deque)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def value(self):
|
||||||
|
return self.deque[-1]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.fmt.format(
|
||||||
|
median=self.median,
|
||||||
|
avg=self.avg,
|
||||||
|
global_avg=self.global_avg,
|
||||||
|
max=self.max,
|
||||||
|
value=self.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache()
|
||||||
|
def _get_global_gloo_group():
|
||||||
|
"""
|
||||||
|
Return a process group based on gloo backend, containing all the ranks
|
||||||
|
The result is cached.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if dist.get_backend() == "nccl":
|
||||||
|
return dist.new_group(backend="gloo")
|
||||||
|
|
||||||
|
return dist.group.WORLD
|
||||||
|
|
||||||
|
|
||||||
|
def all_gather_cpu(data):
|
||||||
|
"""
|
||||||
|
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
||||||
|
Args:
|
||||||
|
data: any picklable object
|
||||||
|
Returns:
|
||||||
|
list[data]: list of data gathered from each rank
|
||||||
|
"""
|
||||||
|
|
||||||
|
world_size = get_world_size()
|
||||||
|
if world_size == 1:
|
||||||
|
return [data]
|
||||||
|
|
||||||
|
cpu_group = _get_global_gloo_group()
|
||||||
|
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
torch.save(data, buffer)
|
||||||
|
data_view = buffer.getbuffer()
|
||||||
|
device = "cuda" if cpu_group is None else "cpu"
|
||||||
|
tensor = torch.ByteTensor(data_view).to(device)
|
||||||
|
|
||||||
|
# obtain Tensor size of each rank
|
||||||
|
local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
|
||||||
|
size_list = [torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)]
|
||||||
|
if cpu_group is None:
|
||||||
|
dist.all_gather(size_list, local_size)
|
||||||
|
else:
|
||||||
|
print("gathering on cpu")
|
||||||
|
dist.all_gather(size_list, local_size, group=cpu_group)
|
||||||
|
size_list = [int(size.item()) for size in size_list]
|
||||||
|
max_size = max(size_list)
|
||||||
|
assert isinstance(local_size.item(), int)
|
||||||
|
local_size = int(local_size.item())
|
||||||
|
|
||||||
|
# receiving Tensor from all ranks
|
||||||
|
# we pad the tensor because torch all_gather does not support
|
||||||
|
# gathering tensors of different shapes
|
||||||
|
tensor_list = []
|
||||||
|
for _ in size_list:
|
||||||
|
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
|
||||||
|
if local_size != max_size:
|
||||||
|
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device=device)
|
||||||
|
tensor = torch.cat((tensor, padding), dim=0)
|
||||||
|
if cpu_group is None:
|
||||||
|
dist.all_gather(tensor_list, tensor)
|
||||||
|
else:
|
||||||
|
dist.all_gather(tensor_list, tensor, group=cpu_group)
|
||||||
|
|
||||||
|
data_list = []
|
||||||
|
for size, tensor in zip(size_list, tensor_list):
|
||||||
|
tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
|
||||||
|
buffer = io.BytesIO(tensor.cpu().numpy())
|
||||||
|
obj = torch.load(buffer)
|
||||||
|
data_list.append(obj)
|
||||||
|
|
||||||
|
return data_list
|
||||||
|
|
||||||
|
|
||||||
|
def all_gather(data):
|
||||||
|
"""
|
||||||
|
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
||||||
|
Args:
|
||||||
|
data: any picklable object
|
||||||
|
Returns:
|
||||||
|
list[data]: list of data gathered from each rank
|
||||||
|
"""
|
||||||
|
|
||||||
|
if os.getenv("CPU_REDUCE") == "1":
|
||||||
|
return all_gather_cpu(data)
|
||||||
|
|
||||||
|
world_size = get_world_size()
|
||||||
|
if world_size == 1:
|
||||||
|
return [data]
|
||||||
|
|
||||||
|
# serialized to a Tensor
|
||||||
|
buffer = pickle.dumps(data)
|
||||||
|
storage = torch.ByteStorage.from_buffer(buffer)
|
||||||
|
tensor = torch.ByteTensor(storage).to("cuda")
|
||||||
|
|
||||||
|
# obtain Tensor size of each rank
|
||||||
|
local_size = torch.tensor([tensor.numel()], device="cuda")
|
||||||
|
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
|
||||||
|
dist.all_gather(size_list, local_size)
|
||||||
|
size_list = [int(size.item()) for size in size_list]
|
||||||
|
max_size = max(size_list)
|
||||||
|
|
||||||
|
# receiving Tensor from all ranks
|
||||||
|
# we pad the tensor because torch all_gather does not support
|
||||||
|
# gathering tensors of different shapes
|
||||||
|
tensor_list = []
|
||||||
|
for _ in size_list:
|
||||||
|
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
|
||||||
|
if local_size != max_size:
|
||||||
|
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
|
||||||
|
tensor = torch.cat((tensor, padding), dim=0)
|
||||||
|
dist.all_gather(tensor_list, tensor)
|
||||||
|
|
||||||
|
data_list = []
|
||||||
|
for size, tensor in zip(size_list, tensor_list):
|
||||||
|
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||||
|
data_list.append(pickle.loads(buffer))
|
||||||
|
|
||||||
|
return data_list
|
||||||
|
|
||||||
|
|
||||||
|
def reduce_dict(input_dict, average=True):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
input_dict (dict): all the values will be reduced
|
||||||
|
average (bool): whether to do average or sum
|
||||||
|
Reduce the values in the dictionary from all processes so that all processes
|
||||||
|
have the averaged results. Returns a dict with the same fields as
|
||||||
|
input_dict, after reduction.
|
||||||
|
"""
|
||||||
|
world_size = get_world_size()
|
||||||
|
if world_size < 2:
|
||||||
|
return input_dict
|
||||||
|
with torch.no_grad():
|
||||||
|
names = []
|
||||||
|
values = []
|
||||||
|
# sort the keys so that they are consistent across processes
|
||||||
|
for k in sorted(input_dict.keys()):
|
||||||
|
names.append(k)
|
||||||
|
values.append(input_dict[k])
|
||||||
|
values = torch.stack(values, dim=0)
|
||||||
|
dist.all_reduce(values)
|
||||||
|
if average:
|
||||||
|
values /= world_size
|
||||||
|
reduced_dict = {k: v for k, v in zip(names, values)}
|
||||||
|
return reduced_dict
|
||||||
|
|
||||||
|
|
||||||
|
class MetricLogger(object):
|
||||||
|
def __init__(self, delimiter="\t"):
|
||||||
|
self.meters = defaultdict(SmoothedValue)
|
||||||
|
self.delimiter = delimiter
|
||||||
|
|
||||||
|
def update(self, **kwargs):
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
v = v.item()
|
||||||
|
assert isinstance(v, (float, int))
|
||||||
|
self.meters[k].update(v)
|
||||||
|
|
||||||
|
def __getattr__(self, attr):
|
||||||
|
if attr in self.meters:
|
||||||
|
return self.meters[attr]
|
||||||
|
if attr in self.__dict__:
|
||||||
|
return self.__dict__[attr]
|
||||||
|
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
loss_str = []
|
||||||
|
for name, meter in self.meters.items():
|
||||||
|
# print(name, str(meter))
|
||||||
|
# import ipdb;ipdb.set_trace()
|
||||||
|
if meter.count > 0:
|
||||||
|
loss_str.append("{}: {}".format(name, str(meter)))
|
||||||
|
return self.delimiter.join(loss_str)
|
||||||
|
|
||||||
|
def synchronize_between_processes(self):
|
||||||
|
for meter in self.meters.values():
|
||||||
|
meter.synchronize_between_processes()
|
||||||
|
|
||||||
|
def add_meter(self, name, meter):
|
||||||
|
self.meters[name] = meter
|
||||||
|
|
||||||
|
def log_every(self, iterable, print_freq, header=None, logger=None):
|
||||||
|
if logger is None:
|
||||||
|
print_func = print
|
||||||
|
else:
|
||||||
|
print_func = logger.info
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
if not header:
|
||||||
|
header = ""
|
||||||
|
start_time = time.time()
|
||||||
|
end = time.time()
|
||||||
|
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
||||||
|
data_time = SmoothedValue(fmt="{avg:.4f}")
|
||||||
|
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
log_msg = self.delimiter.join(
|
||||||
|
[
|
||||||
|
header,
|
||||||
|
"[{0" + space_fmt + "}/{1}]",
|
||||||
|
"eta: {eta}",
|
||||||
|
"{meters}",
|
||||||
|
"time: {time}",
|
||||||
|
"data: {data}",
|
||||||
|
"max mem: {memory:.0f}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
log_msg = self.delimiter.join(
|
||||||
|
[
|
||||||
|
header,
|
||||||
|
"[{0" + space_fmt + "}/{1}]",
|
||||||
|
"eta: {eta}",
|
||||||
|
"{meters}",
|
||||||
|
"time: {time}",
|
||||||
|
"data: {data}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
MB = 1024.0 * 1024.0
|
||||||
|
for obj in iterable:
|
||||||
|
data_time.update(time.time() - end)
|
||||||
|
yield obj
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
iter_time.update(time.time() - end)
|
||||||
|
if i % print_freq == 0 or i == len(iterable) - 1:
|
||||||
|
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||||
|
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print_func(
|
||||||
|
log_msg.format(
|
||||||
|
i,
|
||||||
|
len(iterable),
|
||||||
|
eta=eta_string,
|
||||||
|
meters=str(self),
|
||||||
|
time=str(iter_time),
|
||||||
|
data=str(data_time),
|
||||||
|
memory=torch.cuda.max_memory_allocated() / MB,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print_func(
|
||||||
|
log_msg.format(
|
||||||
|
i,
|
||||||
|
len(iterable),
|
||||||
|
eta=eta_string,
|
||||||
|
meters=str(self),
|
||||||
|
time=str(iter_time),
|
||||||
|
data=str(data_time),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
i += 1
|
||||||
|
end = time.time()
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||||
|
print_func(
|
||||||
|
"{} Total time: {} ({:.4f} s / it)".format(
|
||||||
|
header, total_time_str, total_time / len(iterable)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_sha():
|
||||||
|
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
def _run(command):
|
||||||
|
return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
|
||||||
|
|
||||||
|
sha = "N/A"
|
||||||
|
diff = "clean"
|
||||||
|
branch = "N/A"
|
||||||
|
try:
|
||||||
|
sha = _run(["git", "rev-parse", "HEAD"])
|
||||||
|
subprocess.check_output(["git", "diff"], cwd=cwd)
|
||||||
|
diff = _run(["git", "diff-index", "HEAD"])
|
||||||
|
diff = "has uncommited changes" if diff else "clean"
|
||||||
|
branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(batch):
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
batch = list(zip(*batch))
|
||||||
|
batch[0] = nested_tensor_from_tensor_list(batch[0])
|
||||||
|
return tuple(batch)
|
||||||
|
|
||||||
|
|
||||||
|
def _max_by_axis(the_list):
|
||||||
|
# type: (List[List[int]]) -> List[int]
|
||||||
|
maxes = the_list[0]
|
||||||
|
for sublist in the_list[1:]:
|
||||||
|
for index, item in enumerate(sublist):
|
||||||
|
maxes[index] = max(maxes[index], item)
|
||||||
|
return maxes
|
||||||
|
|
||||||
|
|
||||||
|
class NestedTensor(object):
|
||||||
|
def __init__(self, tensors, mask: Optional[Tensor]):
|
||||||
|
self.tensors = tensors
|
||||||
|
self.mask = mask
|
||||||
|
if mask == "auto":
|
||||||
|
self.mask = torch.zeros_like(tensors).to(tensors.device)
|
||||||
|
if self.mask.dim() == 3:
|
||||||
|
self.mask = self.mask.sum(0).to(bool)
|
||||||
|
elif self.mask.dim() == 4:
|
||||||
|
self.mask = self.mask.sum(1).to(bool)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"tensors dim must be 3 or 4 but {}({})".format(
|
||||||
|
self.tensors.dim(), self.tensors.shape
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def imgsize(self):
|
||||||
|
res = []
|
||||||
|
for i in range(self.tensors.shape[0]):
|
||||||
|
mask = self.mask[i]
|
||||||
|
maxH = (~mask).sum(0).max()
|
||||||
|
maxW = (~mask).sum(1).max()
|
||||||
|
res.append(torch.Tensor([maxH, maxW]))
|
||||||
|
return res
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
# type: (Device) -> NestedTensor # noqa
|
||||||
|
cast_tensor = self.tensors.to(device)
|
||||||
|
mask = self.mask
|
||||||
|
if mask is not None:
|
||||||
|
assert mask is not None
|
||||||
|
cast_mask = mask.to(device)
|
||||||
|
else:
|
||||||
|
cast_mask = None
|
||||||
|
return NestedTensor(cast_tensor, cast_mask)
|
||||||
|
|
||||||
|
def to_img_list_single(self, tensor, mask):
|
||||||
|
assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(tensor.dim())
|
||||||
|
maxH = (~mask).sum(0).max()
|
||||||
|
maxW = (~mask).sum(1).max()
|
||||||
|
img = tensor[:, :maxH, :maxW]
|
||||||
|
return img
|
||||||
|
|
||||||
|
def to_img_list(self):
|
||||||
|
"""remove the padding and convert to img list
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[type]: [description]
|
||||||
|
"""
|
||||||
|
if self.tensors.dim() == 3:
|
||||||
|
return self.to_img_list_single(self.tensors, self.mask)
|
||||||
|
else:
|
||||||
|
res = []
|
||||||
|
for i in range(self.tensors.shape[0]):
|
||||||
|
tensor_i = self.tensors[i]
|
||||||
|
mask_i = self.mask[i]
|
||||||
|
res.append(self.to_img_list_single(tensor_i, mask_i))
|
||||||
|
return res
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return self.tensors.device
|
||||||
|
|
||||||
|
def decompose(self):
|
||||||
|
return self.tensors, self.mask
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return str(self.tensors)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
return {"tensors.shape": self.tensors.shape, "mask.shape": self.mask.shape}
|
||||||
|
|
||||||
|
|
||||||
|
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||||
|
# TODO make this more general
|
||||||
|
if tensor_list[0].ndim == 3:
|
||||||
|
if torchvision._is_tracing():
|
||||||
|
# nested_tensor_from_tensor_list() does not export well to ONNX
|
||||||
|
# call _onnx_nested_tensor_from_tensor_list() instead
|
||||||
|
return _onnx_nested_tensor_from_tensor_list(tensor_list)
|
||||||
|
|
||||||
|
# TODO make it support different-sized images
|
||||||
|
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
||||||
|
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
|
||||||
|
batch_shape = [len(tensor_list)] + max_size
|
||||||
|
b, c, h, w = batch_shape
|
||||||
|
dtype = tensor_list[0].dtype
|
||||||
|
device = tensor_list[0].device
|
||||||
|
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||||
|
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
|
||||||
|
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
||||||
|
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||||
|
m[: img.shape[1], : img.shape[2]] = False
|
||||||
|
else:
|
||||||
|
raise ValueError("not supported")
|
||||||
|
return NestedTensor(tensor, mask)
|
||||||
|
|
||||||
|
|
||||||
|
# _onnx_nested_tensor_from_tensor_list() is an implementation of
|
||||||
|
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
|
||||||
|
@torch.jit.unused
|
||||||
|
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
|
||||||
|
max_size = []
|
||||||
|
for i in range(tensor_list[0].dim()):
|
||||||
|
max_size_i = torch.max(
|
||||||
|
torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
|
||||||
|
).to(torch.int64)
|
||||||
|
max_size.append(max_size_i)
|
||||||
|
max_size = tuple(max_size)
|
||||||
|
|
||||||
|
# work around for
|
||||||
|
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||||
|
# m[: img.shape[1], :img.shape[2]] = False
|
||||||
|
# which is not yet supported in onnx
|
||||||
|
padded_imgs = []
|
||||||
|
padded_masks = []
|
||||||
|
for img in tensor_list:
|
||||||
|
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
|
||||||
|
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
|
||||||
|
padded_imgs.append(padded_img)
|
||||||
|
|
||||||
|
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
|
||||||
|
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
|
||||||
|
padded_masks.append(padded_mask.to(torch.bool))
|
||||||
|
|
||||||
|
tensor = torch.stack(padded_imgs)
|
||||||
|
mask = torch.stack(padded_masks)
|
||||||
|
|
||||||
|
return NestedTensor(tensor, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_for_distributed(is_master):
|
||||||
|
"""
|
||||||
|
This function disables printing when not in master process
|
||||||
|
"""
|
||||||
|
import builtins as __builtin__
|
||||||
|
|
||||||
|
builtin_print = __builtin__.print
|
||||||
|
|
||||||
|
def print(*args, **kwargs):
|
||||||
|
force = kwargs.pop("force", False)
|
||||||
|
if is_master or force:
|
||||||
|
builtin_print(*args, **kwargs)
|
||||||
|
|
||||||
|
__builtin__.print = print
|
||||||
|
|
||||||
|
|
||||||
|
def is_dist_avail_and_initialized():
|
||||||
|
if not dist.is_available():
|
||||||
|
return False
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_world_size():
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return 1
|
||||||
|
return dist.get_world_size()
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank():
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return 0
|
||||||
|
return dist.get_rank()
|
||||||
|
|
||||||
|
|
||||||
|
def is_main_process():
|
||||||
|
return get_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def save_on_master(*args, **kwargs):
|
||||||
|
if is_main_process():
|
||||||
|
torch.save(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def init_distributed_mode(args):
|
||||||
|
if "WORLD_SIZE" in os.environ and os.environ["WORLD_SIZE"] != "": # 'RANK' in os.environ and
|
||||||
|
args.rank = int(os.environ["RANK"])
|
||||||
|
args.world_size = int(os.environ["WORLD_SIZE"])
|
||||||
|
args.gpu = args.local_rank = int(os.environ["LOCAL_RANK"])
|
||||||
|
|
||||||
|
# launch by torch.distributed.launch
|
||||||
|
# Single node
|
||||||
|
# python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 1 --rank 0 ...
|
||||||
|
# Multi nodes
|
||||||
|
# python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 0 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ...
|
||||||
|
# python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 1 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ...
|
||||||
|
# args.rank = int(os.environ.get('OMPI_COMM_WORLD_RANK'))
|
||||||
|
# local_world_size = int(os.environ['GPU_PER_NODE_COUNT'])
|
||||||
|
# args.world_size = args.world_size * local_world_size
|
||||||
|
# args.gpu = args.local_rank = int(os.environ['LOCAL_RANK'])
|
||||||
|
# args.rank = args.rank * local_world_size + args.local_rank
|
||||||
|
print(
|
||||||
|
"world size: {}, rank: {}, local rank: {}".format(
|
||||||
|
args.world_size, args.rank, args.local_rank
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(json.dumps(dict(os.environ), indent=2))
|
||||||
|
elif "SLURM_PROCID" in os.environ:
|
||||||
|
args.rank = int(os.environ["SLURM_PROCID"])
|
||||||
|
args.gpu = args.local_rank = int(os.environ["SLURM_LOCALID"])
|
||||||
|
args.world_size = int(os.environ["SLURM_NPROCS"])
|
||||||
|
|
||||||
|
print(
|
||||||
|
"world size: {}, world rank: {}, local rank: {}, device_count: {}".format(
|
||||||
|
args.world_size, args.rank, args.local_rank, torch.cuda.device_count()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("Not using distributed mode")
|
||||||
|
args.distributed = False
|
||||||
|
args.world_size = 1
|
||||||
|
args.rank = 0
|
||||||
|
args.local_rank = 0
|
||||||
|
return
|
||||||
|
|
||||||
|
print("world_size:{} rank:{} local_rank:{}".format(args.world_size, args.rank, args.local_rank))
|
||||||
|
args.distributed = True
|
||||||
|
torch.cuda.set_device(args.local_rank)
|
||||||
|
args.dist_backend = "nccl"
|
||||||
|
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
|
||||||
|
|
||||||
|
torch.distributed.init_process_group(
|
||||||
|
backend=args.dist_backend,
|
||||||
|
world_size=args.world_size,
|
||||||
|
rank=args.rank,
|
||||||
|
init_method=args.dist_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Before torch.distributed.barrier()")
|
||||||
|
torch.distributed.barrier()
|
||||||
|
print("End torch.distributed.barrier()")
|
||||||
|
setup_for_distributed(args.rank == 0)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def accuracy(output, target, topk=(1,)):
|
||||||
|
"""Computes the precision@k for the specified values of k"""
|
||||||
|
if target.numel() == 0:
|
||||||
|
return [torch.zeros([], device=output.device)]
|
||||||
|
maxk = max(topk)
|
||||||
|
batch_size = target.size(0)
|
||||||
|
|
||||||
|
_, pred = output.topk(maxk, 1, True, True)
|
||||||
|
pred = pred.t()
|
||||||
|
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||||
|
|
||||||
|
res = []
|
||||||
|
for k in topk:
|
||||||
|
correct_k = correct[:k].view(-1).float().sum(0)
|
||||||
|
res.append(correct_k.mul_(100.0 / batch_size))
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def accuracy_onehot(pred, gt):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred (_type_): n, c
|
||||||
|
gt (_type_): n, c
|
||||||
|
"""
|
||||||
|
tp = ((pred - gt).abs().sum(-1) < 1e-4).float().sum()
|
||||||
|
acc = tp / gt.shape[0] * 100
|
||||||
|
return acc
|
||||||
|
|
||||||
|
|
||||||
|
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
||||||
|
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
|
||||||
|
"""
|
||||||
|
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
|
||||||
|
This will eventually be supported natively by PyTorch, and this
|
||||||
|
class can go away.
|
||||||
|
"""
|
||||||
|
if __torchvision_need_compat_flag < 0.7:
|
||||||
|
if input.numel() > 0:
|
||||||
|
return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
|
||||||
|
|
||||||
|
output_shape = _output_size(2, input, size, scale_factor)
|
||||||
|
output_shape = list(input.shape[:-2]) + list(output_shape)
|
||||||
|
return _new_empty_tensor(input, output_shape)
|
||||||
|
else:
|
||||||
|
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
|
||||||
|
|
||||||
|
|
||||||
|
class color_sys:
|
||||||
|
def __init__(self, num_colors) -> None:
|
||||||
|
self.num_colors = num_colors
|
||||||
|
colors = []
|
||||||
|
for i in np.arange(0.0, 360.0, 360.0 / num_colors):
|
||||||
|
hue = i / 360.0
|
||||||
|
lightness = (50 + np.random.rand() * 10) / 100.0
|
||||||
|
saturation = (90 + np.random.rand() * 10) / 100.0
|
||||||
|
colors.append(
|
||||||
|
tuple([int(j * 255) for j in colorsys.hls_to_rgb(hue, lightness, saturation)])
|
||||||
|
)
|
||||||
|
self.colors = colors
|
||||||
|
|
||||||
|
def __call__(self, idx):
|
||||||
|
return self.colors[idx]
|
||||||
|
|
||||||
|
|
||||||
|
def inverse_sigmoid(x, eps=1e-3):
|
||||||
|
x = x.clamp(min=0, max=1)
|
||||||
|
x1 = x.clamp(min=eps)
|
||||||
|
x2 = (1 - x).clamp(min=eps)
|
||||||
|
return torch.log(x1 / x2)
|
||||||
|
|
||||||
|
|
||||||
|
def clean_state_dict(state_dict):
|
||||||
|
new_state_dict = OrderedDict()
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
if k[:7] == "module.":
|
||||||
|
k = k[7:] # remove `module.`
|
||||||
|
new_state_dict[k] = v
|
||||||
|
return new_state_dict
|