parent
3bc51e4046
commit
49350f308c
@ -0,0 +1,197 @@
|
||||
# Standard-library imports.
import asyncio
import argparse
from collections import Counter
import json
import pathlib
import re

# Third-party: discord.py for the bot, gradio for loading and running Spaces.
import discord
from discord.ext import commands
import gradio as gr
from gradio import utils
import requests

from typing import Dict, List

# Project helpers: update_pickle_file, read_pickle_file, send_file_or_text,
# remove_tags, hash_user_id (defined in the utils module).
from utils import *

# Serializes access for background pickle-file writers.
lock = asyncio.Lock()

# Empty command prefix: every readable message is treated as potential input.
# NOTE(review): the message_content intent is not enabled here — confirm the
# bot still receives message text with only messages/guilds intents.
bot = commands.Bot("", intents=discord.Intents(messages=True, guilds=True))
|
||||
|
||||
|
||||
# Pickle file persisting the mapping: guild id -> name of its loaded Space.
GUILD_SPACES_FILE = "guild_spaces.pkl"

if pathlib.Path(GUILD_SPACES_FILE).exists():
    # Restore the persisted guild -> Space mapping from a previous run.
    guild_spaces = read_pickle_file(GUILD_SPACES_FILE)
    assert isinstance(guild_spaces, dict), f"{GUILD_SPACES_FILE} in invalid format."
    guild_blocks = {}
    delete_keys = []
    for k, v in guild_spaces.items():
        try:
            # Re-load each persisted Space; ValueError means the Space no
            # longer exists (or is invalid), so drop it from the mapping.
            guild_blocks[k] = gr.Interface.load(v, src="spaces")
        except ValueError:
            delete_keys.append(k)
    # Deleting after the loop: mutating a dict while iterating it raises.
    for k in delete_keys:
        del guild_spaces[k]
else:
    # First run: start with empty state.
    guild_spaces: Dict[int, str] = {}
    guild_blocks: Dict[int, gr.Blocks] = {}
|
||||
|
||||
|
||||
# Pickle file persisting the list of SHA-256-hashed user ids seen so far
# (used only for aggregate stats; raw ids are never stored).
HASHED_USERS_FILE = "users.pkl"

if pathlib.Path(HASHED_USERS_FILE).exists():
    hashed_users = read_pickle_file(HASHED_USERS_FILE)
    assert isinstance(hashed_users, list), f"{HASHED_USERS_FILE} in invalid format."
else:
    # First run: no users recorded yet.
    hashed_users: List[str] = []
|
||||
|
||||
|
||||
@bot.event
async def on_ready():
    """Announce a successful login and how many guilds the bot serves."""
    startup_lines = (
        f"Logged in as {bot.user}",
        f"Running in {len(bot.guilds)} servers...",
    )
    for line in startup_lines:
        print(line)
|
||||
|
||||
|
||||
async def run_prediction(space: gr.Blocks, *inputs):
    """Run the first function (fn_index 0) of a loaded Space on *inputs*.

    Serializes the raw inputs, invokes the Space's API, deserializes the
    outputs, and unwraps a single output to a bare value.
    """
    inputs = list(inputs)
    # Always call the first registered function of the Space.
    fn_index = 0
    processed_inputs = space.serialize_data(fn_index=fn_index, inputs=inputs)
    batch = space.dependencies[fn_index]["batch"]

    if batch:
        # Batched endpoints expect each input wrapped in a single-item batch.
        processed_inputs = [[inp] for inp in processed_inputs]

    outputs = await space.process_api(
        fn_index=fn_index, inputs=processed_inputs, request=None, state={}
    )
    outputs = outputs["data"]

    if batch:
        # Unwrap the single-item batches again on the way out.
        outputs = [out[0] for out in outputs]

    processed_outputs = space.deserialize_data(fn_index, outputs)
    # A one-element output list is collapsed to its single value.
    processed_outputs = utils.resolve_singleton(processed_outputs)

    return processed_outputs
|
||||
|
||||
|
||||
async def display_stats(message: discord.Message):
    """Post bot-wide usage statistics to the channel *message* arrived in."""
    channel = message.channel
    summary = (
        f"Running in {len(bot.guilds)} servers\n"
        f"Total # of users: {len(hashed_users)}\n"
        f"------------------"
    )
    await channel.send(summary)
    await channel.send(f"Most popular spaces:")
    # Count how often each Space is currently loaded across guilds and
    # report the ten most frequent ones.
    popularity = Counter(guild_spaces.values())
    for space_name, uses in popularity.most_common(10):
        await channel.send(f"- {space_name}: {uses}")
|
||||
|
||||
|
||||
async def load_space(guild: discord.Guild, message: discord.Message, content: str):
    """Load the Hugging Face Space named *content* for *guild*.

    Verifies the Space exists, loads it as a gradio Interface, persists the
    guild->Space mapping, and renames the bot to advertise the loaded Space.
    """
    # NOTE(review): requests.get is a blocking call inside a coroutine; it
    # stalls the event loop for the duration of the HTTP request.
    iframe_url = (
        requests.get(f"https://huggingface.co/api/spaces/{content}/host")
        .json()
        .get("host")
    )
    if iframe_url is None:
        # No host entry means the Space does not exist.
        return await message.channel.send(
            f"Space: {content} not found. If you'd like to make a prediction, enclose the inputs in quotation marks."
        )
    else:
        await message.channel.send(
            f"Loading Space: https://huggingface.co/spaces/{content}..."
        )
        interface = gr.Interface.load(content, src="spaces")
        guild_spaces[guild.id] = content
        guild_blocks[guild.id] = interface
        # Persist the updated guild->Space mapping in the background.
        asyncio.create_task(update_pickle_file(guild_spaces, GUILD_SPACES_FILE))
        # Discord nicknames are capped at 32 characters: truncate the Space
        # name with "..." if "<bot name> [<space>]" would exceed the cap.
        if len(content) > 32 - len(f"{bot.name} []"):  # type: ignore
            nickname = content[: 32 - len(f"{bot.name} []") - 3] + "..."  # type: ignore
        else:
            nickname = content
        nickname = f"{bot.name} [{nickname}]"  # type: ignore
        await guild.me.edit(nick=nickname)
        await message.channel.send(
            "Ready to make predictions! Type in your inputs and enclose them in quotation marks."
        )
|
||||
|
||||
|
||||
async def disconnect_space(bot: commands.Bot, guild: discord.Guild):
    """Unload the guild's Space and restore the bot's default nickname."""
    # Remove both the persisted name and the loaded interface; pop() with a
    # default makes this a no-op when nothing was loaded.
    for registry in (guild_spaces, guild_blocks):
        registry.pop(guild.id, None)
    asyncio.create_task(update_pickle_file(guild_spaces, GUILD_SPACES_FILE))
    await guild.me.edit(nick=bot.name)  # type: ignore
|
||||
|
||||
|
||||
async def make_prediction(guild: discord.Guild, message: discord.Message, content: str):
    """Run the guild's loaded Space on the quoted inputs in *content* and
    send the result(s) back to the channel."""
    if guild.id in guild_spaces:
        # Split on spaces that directly precede a double quote, so quoted
        # arguments may contain spaces; then strip surrounding quotes.
        params = re.split(r' (?=")', content)
        params = [p.strip("'\"") for p in params]
        space = guild_blocks[guild.id]
        predictions = await run_prediction(space, *params)
        # Multi-output Spaces return a tuple/list; send each output
        # separately (as a file when it names one, otherwise as text).
        if isinstance(predictions, (tuple, list)):
            for p in predictions:
                await send_file_or_text(message.channel, p)
        else:
            await send_file_or_text(message.channel, predictions)
        return
    else:
        await message.channel.send(
            "No Space is currently running. Please type in the name of a Hugging Face Space name first, e.g. abidlabs/en2fr"
        )
        # Reset the nickname since no Space is active.
        await guild.me.edit(nick=bot.name)  # type: ignore
|
||||
|
||||
|
||||
@bot.event
async def on_message(message: discord.Message):
    """Dispatch incoming messages: register new users, then route commands."""
    if message.author == bot.user:
        # Ignore the bot's own messages.
        return
    h = hash_user_id(message.author.id)
    if h not in hashed_users:
        # First time we see this user: persist their hashed id.
        # NOTE(review): a new user's first message is only used for
        # registration and is never dispatched below — confirm intended.
        hashed_users.append(h)
        asyncio.create_task(update_pickle_file(hashed_users, HASHED_USERS_FILE))
    else:
        if message.content:
            content = remove_tags(message.content)
            guild = message.channel.guild
            # NOTE(review): assert is stripped under -O; validation only.
            assert guild, "Message not sent in a guild."

            if content.strip() == "exit":
                await disconnect_space(bot, guild)
            elif content.strip() == "stats":
                await display_stats(message)
            elif content.startswith('"') or content.startswith("'"):
                # Quoted text means "run the loaded Space on these inputs".
                await make_prediction(guild, message, content)
            else:
                # Anything else is treated as a Space name to load.
                await load_space(guild, message, content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # A --token on the command line selects the staging bot; otherwise the
    # production token is read from secrets.json.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--token",
        type=str,
        help="API key for the Discord bot. You can set this to your Discord token if you'd like to make your own clone of the Gradio Bot.",
        required=False,
        default="",
    )
    args = parser.parse_args()

    cli_token_given = bool(args.token.strip())
    if cli_token_given:
        discord_token = args.token
        bot.env = "staging"  # type: ignore
        bot.name = "StagingBot"  # type: ignore
    else:
        with open("secrets.json") as fp:
            discord_token = json.load(fp)["discord_token"]
        bot.env = "prod"  # type: ignore
        bot.name = "GradioBot"  # type: ignore

    bot.run(discord_token)
|
@ -0,0 +1,41 @@
|
||||
from __future__ import annotations

import asyncio
import pickle
import hashlib
import pathlib
from typing import Dict, List

import discord

# Guards concurrent writes to the pickle files from background tasks.
lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def update_pickle_file(data: Dict | List, file_path: str):
    """Overwrite *file_path* with the pickled *data*.

    Holds the module-level lock so concurrent background writers cannot
    interleave their writes to the same file.
    """
    async with lock:
        with open(file_path, "wb") as out:
            pickle.dump(data, out)
|
||||
|
||||
|
||||
def read_pickle_file(file_path: str):
    """Deserialize and return the object pickled at *file_path*.

    NOTE: only ever called on files this bot wrote itself; pickle must not
    be used on untrusted data.
    """
    payload = pathlib.Path(file_path).read_bytes()
    return pickle.loads(payload)
|
||||
|
||||
|
||||
async def send_file_or_text(channel, file_or_text: str):
    """Send *file_or_text* to *channel*: as a file attachment when it names
    an existing file, otherwise as a plain text message.

    Returns the message object produced by ``channel.send``.
    """
    # Predictions can be arbitrary text; Path.exists() raises OSError (path
    # component too long) or ValueError (embedded NUL byte) on such strings,
    # so treat any un-checkable value as plain text instead of crashing.
    try:
        is_existing_file = pathlib.Path(str(file_or_text)).exists()
    except (OSError, ValueError):
        is_existing_file = False

    if is_existing_file:
        with open(file_or_text, "rb") as f:
            return await channel.send(file=discord.File(f))
    else:
        return await channel.send(file_or_text)
|
||||
|
||||
|
||||
def remove_tags(content: str) -> str:
    """Strip the bot's known mention tags from *content*, then trim
    surrounding whitespace."""
    # Hard-coded mention ids of the two bot accounts (staging/production).
    for mention in ("<@1040198143695933501>", "<@1057338428938788884>"):
        content = content.replace(mention, "")
    return content.strip()
|
||||
|
||||
|
||||
def hash_user_id(user_id: int) -> str:
    """Return the hex SHA-256 digest of the decimal user id.

    Used to record users for aggregate stats without storing raw ids.
    """
    digest = hashlib.sha256(str(user_id).encode("utf-8"))
    return digest.hexdigest()
|
@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 pliny
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
@ -1,62 +0,0 @@
|
||||
from flask import Flask, request, jsonify
import openai
import logging
from dalle3 import Dalle
import os
import gradio as gr
import requests
from PIL import Image
from io import BytesIO
import dotenv

app = Flask(__name__)

# Load OPENAI_API_KEY / DALLE_COOKIE from a local .env file.
dotenv.load_dotenv(".env")

# Initialize OpenAI API, INPUT YOUR OWN OPENAI KEY
openai.api_key = os.getenv("OPENAI_API_KEY")

# Initialize DALLE3 API, INPUT YOUR OWN COOKIE
# NOTE(review): os.getenv returns None when DALLE_COOKIE is unset; confirm
# Dalle(None) is acceptable or fail fast here.
cookie = os.getenv("DALLE_COOKIE")
dalle = Dalle(cookie)
|
||||
|
||||
|
||||
def interpret_text_with_gpt(text):
    """Ask GPT to turn *text* into a refined comic-panel prompt."""
    model_engine = "text-davinci-002"
    panel_instructions = "Create a comic panel where"
    refined_prompt = f"{panel_instructions} {text}"

    completion = openai.Completion.create(
        engine=model_engine,
        prompt=refined_prompt,
        max_tokens=100,
    )
    # Use only the first completion, trimmed of surrounding whitespace.
    return completion.choices[0].text.strip()
|
||||
|
||||
def generate_images_with_dalle(refined_prompt):
    """Submit *refined_prompt* to DALL-E 3 and return the resulting image URLs."""
    dalle.create(refined_prompt)
    return dalle.get_urls()
|
||||
|
||||
def gradio_interface(text):
    """End-to-end pipeline: refine *text* with GPT, render panels with
    DALL-E, and return a list of (PIL image, caption) pairs."""
    refined_prompt = interpret_text_with_gpt(text)
    panel_urls = generate_images_with_dalle(refined_prompt)

    panels = []
    for index, url in enumerate(panel_urls, start=1):
        # Download each generated panel and decode it into a PIL image.
        response = requests.get(url)
        image = Image.open(BytesIO(response.content))
        panels.append((image, f"Caption for panel {index}"))

    return panels
|
||||
|
||||
# Gradio UI wiring.
# NOTE(review): gradio_interface returns a list of (image, caption) tuples,
# while outputs declares one Image component plus one text box — confirm
# these line up for multi-panel results. gr.outputs is the legacy API.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=["text"],
    outputs=[gr.outputs.Image(type="pil", label="Comic Panels"), "text"]
)

iface.launch()
|
Loading…
Reference in new issue