api clean up

pull/160/head
Kye 2 years ago
parent f055f051a3
commit 35eabd4cfd

@ -13,10 +13,10 @@ from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from datasets import load_dataset
from PIL import Image
import flask
from flask import request, jsonify
# import flask
# from flask import request, jsonify
import waitress
from flask_cors import CORS
# from flask_cors import CORS
import io
from torchvision import transforms
import torch
@ -62,14 +62,14 @@ port = config["local_inference_endpoint"]["port"]
# Module-level setup for the local inference endpoint.
local_deployment = config["local_deployment"]
device = config.get("device", "cuda:0")  # falls back to the first CUDA device when unset
# Optional outbound HTTPS proxy pulled from config; None disables proxying.
PROXY = None
if config["proxy"]:
    PROXY = {
        "https": config["proxy"],
    }
# NOTE(review): commented-out duplicate of the proxy setup above (left by the
# "api clean up" commit); safe to delete once confirmed unused.
# PROXY = None
# if config["proxy"]:
# PROXY = {
# "https": config["proxy"],
# }
# Flask application serving the inference endpoint, with CORS enabled
# so browser clients on other origins can call it.
app = flask.Flask(__name__)
CORS(app)
# NOTE(review): commented-out duplicate of the Flask/CORS setup above.
# app = flask.Flask(__name__)
# CORS(app)
start = time.time()  # startup timestamp — presumably for load-time logging elsewhere; confirm

@ -18,10 +18,10 @@ from diffusers.utils import load_image
from pydub import AudioSegment
import threading
from queue import Queue
import flask
from flask import request, jsonify
# import flask
# from flask import request, jsonify
import waitress
from flask_cors import CORS, cross_origin
# from flask_cors import CORS, cross_origin
from swarms.agents.workers.multi_modal_workers.omni_agent.get_token_ids import get_token_ids_for_task_parsing, get_token_ids_for_choose_model, count_tokens, get_max_context_length
from huggingface_hub.inference_api import InferenceApi
from huggingface_hub.inference_api import ALL_TASKS
@ -1011,60 +1011,60 @@ def cli():
messages.append({"role": "assistant", "content": answer["message"]})
def server():
    """Run the HTTP API server.

    Reads host/port from ``config["http_listen"]``, builds a Flask app with
    CORS enabled, registers three POST endpoints (``/tasks``, ``/results``,
    ``/hugginggpt``) that all delegate to ``chat_huggingface`` with different
    flags, and serves the app with waitress. Blocks until the server stops.
    """
    http_listen = config["http_listen"]
    host = http_listen["host"]
    port = http_listen["port"]

    app = flask.Flask(__name__, static_folder="public", static_url_path="/")
    app.config['DEBUG'] = False
    CORS(app)

    def _handle_chat_request(**chat_kwargs):
        # Shared body of the three endpoints: extract messages and API
        # credentials from the JSON payload (falling back to module-level
        # defaults), validate them, and dispatch to chat_huggingface.
        data = request.get_json()
        messages = data["messages"]
        api_key = data.get("api_key", API_KEY)
        api_endpoint = data.get("api_endpoint", API_ENDPOINT)
        api_type = data.get("api_type", API_TYPE)
        if api_key is None or api_type is None or api_endpoint is None:
            return jsonify({"error": "Please provide api_key, api_type and api_endpoint"})
        response = chat_huggingface(messages, api_key, api_type, api_endpoint, **chat_kwargs)
        return jsonify(response)

    @cross_origin()
    @app.route('/tasks', methods=['POST'])
    def tasks():
        # Planning-only pass: return the task plan without executing models.
        return _handle_chat_request(return_planning=True)

    @cross_origin()
    @app.route('/results', methods=['POST'])
    def results():
        # Execution pass: return raw model results.
        return _handle_chat_request(return_results=True)

    @cross_origin()
    @app.route('/hugginggpt', methods=['POST'])
    def chat():
        # Full pipeline: plan, execute, and compose the final answer.
        return _handle_chat_request()

    print("server running...")
    waitress.serve(app, host=host, port=port)
if __name__ == "__main__":
    # Entry point: dispatch on the mode selected via CLI arguments
    # (``args`` is parsed at module level elsewhere in this file).
    if args.mode == "test":
        test()
    elif args.mode == "server":
        server()
    elif args.mode == "cli":
        cli()
# def server():
# http_listen = config["http_listen"]
# host = http_listen["host"]
# port = http_listen["port"]
# app = flask.Flask(__name__, static_folder="public", static_url_path="/")
# app.config['DEBUG'] = False
# CORS(app)
# @cross_origin()
# @app.route('/tasks', methods=['POST'])
# def tasks():
# data = request.get_json()
# messages = data["messages"]
# api_key = data.get("api_key", API_KEY)
# api_endpoint = data.get("api_endpoint", API_ENDPOINT)
# api_type = data.get("api_type", API_TYPE)
# if api_key is None or api_type is None or api_endpoint is None:
# return jsonify({"error": "Please provide api_key, api_type and api_endpoint"})
# response = chat_huggingface(messages, api_key, api_type, api_endpoint, return_planning=True)
# return jsonify(response)
# @cross_origin()
# @app.route('/results', methods=['POST'])
# def results():
# data = request.get_json()
# messages = data["messages"]
# api_key = data.get("api_key", API_KEY)
# api_endpoint = data.get("api_endpoint", API_ENDPOINT)
# api_type = data.get("api_type", API_TYPE)
# if api_key is None or api_type is None or api_endpoint is None:
# return jsonify({"error": "Please provide api_key, api_type and api_endpoint"})
# response = chat_huggingface(messages, api_key, api_type, api_endpoint, return_results=True)
# return jsonify(response)
# @cross_origin()
# @app.route('/hugginggpt', methods=['POST'])
# def chat():
# data = request.get_json()
# messages = data["messages"]
# api_key = data.get("api_key", API_KEY)
# api_endpoint = data.get("api_endpoint", API_ENDPOINT)
# api_type = data.get("api_type", API_TYPE)
# if api_key is None or api_type is None or api_endpoint is None:
# return jsonify({"error": "Please provide api_key, api_type and api_endpoint"})
# response = chat_huggingface(messages, api_key, api_type, api_endpoint)
# return jsonify(response)
# print("server running...")
# waitress.serve(app, host=host, port=port)
# if __name__ == "__main__":
# if args.mode == "test":
# test()
# elif args.mode == "server":
# server()
# elif args.mode == "cli":
# cli()
Loading…
Cancel
Save