api clean up

NewTools
Kye 2 years ago
parent c591d305c4
commit 158ec1ab16

@ -13,10 +13,10 @@ from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer from transformers import TrOCRProcessor, VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from datasets import load_dataset from datasets import load_dataset
from PIL import Image from PIL import Image
import flask # import flask
from flask import request, jsonify # from flask import request, jsonify
import waitress import waitress
from flask_cors import CORS # from flask_cors import CORS
import io import io
from torchvision import transforms from torchvision import transforms
import torch import torch
@ -62,14 +62,14 @@ port = config["local_inference_endpoint"]["port"]
local_deployment = config["local_deployment"] local_deployment = config["local_deployment"]
device = config.get("device", "cuda:0") device = config.get("device", "cuda:0")
PROXY = None # PROXY = None
if config["proxy"]: # if config["proxy"]:
PROXY = { # PROXY = {
"https": config["proxy"], # "https": config["proxy"],
} # }
app = flask.Flask(__name__) # app = flask.Flask(__name__)
CORS(app) # CORS(app)
start = time.time() start = time.time()

@ -18,10 +18,10 @@ from diffusers.utils import load_image
from pydub import AudioSegment from pydub import AudioSegment
import threading import threading
from queue import Queue from queue import Queue
import flask # import flask
from flask import request, jsonify # from flask import request, jsonify
import waitress import waitress
from flask_cors import CORS, cross_origin # from flask_cors import CORS, cross_origin
from swarms.agents.workers.multi_modal_workers.omni_agent.get_token_ids import get_token_ids_for_task_parsing, get_token_ids_for_choose_model, count_tokens, get_max_context_length from swarms.agents.workers.multi_modal_workers.omni_agent.get_token_ids import get_token_ids_for_task_parsing, get_token_ids_for_choose_model, count_tokens, get_max_context_length
from huggingface_hub.inference_api import InferenceApi from huggingface_hub.inference_api import InferenceApi
from huggingface_hub.inference_api import ALL_TASKS from huggingface_hub.inference_api import ALL_TASKS
@ -1011,60 +1011,60 @@ def cli():
messages.append({"role": "assistant", "content": answer["message"]}) messages.append({"role": "assistant", "content": answer["message"]})
def server(): # def server():
http_listen = config["http_listen"] # http_listen = config["http_listen"]
host = http_listen["host"] # host = http_listen["host"]
port = http_listen["port"] # port = http_listen["port"]
app = flask.Flask(__name__, static_folder="public", static_url_path="/") # app = flask.Flask(__name__, static_folder="public", static_url_path="/")
app.config['DEBUG'] = False # app.config['DEBUG'] = False
CORS(app) # CORS(app)
@cross_origin() # @cross_origin()
@app.route('/tasks', methods=['POST']) # @app.route('/tasks', methods=['POST'])
def tasks(): # def tasks():
data = request.get_json() # data = request.get_json()
messages = data["messages"] # messages = data["messages"]
api_key = data.get("api_key", API_KEY) # api_key = data.get("api_key", API_KEY)
api_endpoint = data.get("api_endpoint", API_ENDPOINT) # api_endpoint = data.get("api_endpoint", API_ENDPOINT)
api_type = data.get("api_type", API_TYPE) # api_type = data.get("api_type", API_TYPE)
if api_key is None or api_type is None or api_endpoint is None: # if api_key is None or api_type is None or api_endpoint is None:
return jsonify({"error": "Please provide api_key, api_type and api_endpoint"}) # return jsonify({"error": "Please provide api_key, api_type and api_endpoint"})
response = chat_huggingface(messages, api_key, api_type, api_endpoint, return_planning=True) # response = chat_huggingface(messages, api_key, api_type, api_endpoint, return_planning=True)
return jsonify(response) # return jsonify(response)
@cross_origin() # @cross_origin()
@app.route('/results', methods=['POST']) # @app.route('/results', methods=['POST'])
def results(): # def results():
data = request.get_json() # data = request.get_json()
messages = data["messages"] # messages = data["messages"]
api_key = data.get("api_key", API_KEY) # api_key = data.get("api_key", API_KEY)
api_endpoint = data.get("api_endpoint", API_ENDPOINT) # api_endpoint = data.get("api_endpoint", API_ENDPOINT)
api_type = data.get("api_type", API_TYPE) # api_type = data.get("api_type", API_TYPE)
if api_key is None or api_type is None or api_endpoint is None: # if api_key is None or api_type is None or api_endpoint is None:
return jsonify({"error": "Please provide api_key, api_type and api_endpoint"}) # return jsonify({"error": "Please provide api_key, api_type and api_endpoint"})
response = chat_huggingface(messages, api_key, api_type, api_endpoint, return_results=True) # response = chat_huggingface(messages, api_key, api_type, api_endpoint, return_results=True)
return jsonify(response) # return jsonify(response)
@cross_origin() # @cross_origin()
@app.route('/hugginggpt', methods=['POST']) # @app.route('/hugginggpt', methods=['POST'])
def chat(): # def chat():
data = request.get_json() # data = request.get_json()
messages = data["messages"] # messages = data["messages"]
api_key = data.get("api_key", API_KEY) # api_key = data.get("api_key", API_KEY)
api_endpoint = data.get("api_endpoint", API_ENDPOINT) # api_endpoint = data.get("api_endpoint", API_ENDPOINT)
api_type = data.get("api_type", API_TYPE) # api_type = data.get("api_type", API_TYPE)
if api_key is None or api_type is None or api_endpoint is None: # if api_key is None or api_type is None or api_endpoint is None:
return jsonify({"error": "Please provide api_key, api_type and api_endpoint"}) # return jsonify({"error": "Please provide api_key, api_type and api_endpoint"})
response = chat_huggingface(messages, api_key, api_type, api_endpoint) # response = chat_huggingface(messages, api_key, api_type, api_endpoint)
return jsonify(response) # return jsonify(response)
print("server running...") # print("server running...")
waitress.serve(app, host=host, port=port) # waitress.serve(app, host=host, port=port)
if __name__ == "__main__": # if __name__ == "__main__":
if args.mode == "test": # if args.mode == "test":
test() # test()
elif args.mode == "server": # elif args.mode == "server":
server() # server()
elif args.mode == "cli": # elif args.mode == "cli":
cli() # cli()
Loading…
Cancel
Save