From 7b3178e2fdf68780de9ce4dcc8161faa8ab8cca3 Mon Sep 17 00:00:00 2001 From: Shiven Mian Date: Sat, 17 Feb 2024 05:30:05 -0800 Subject: [PATCH] feat: teach mode + accumulator fixes --- 01OS/01OS/clients/base_device.py | 5 +-- 01OS/01OS/server/server.py | 49 ++++++++++++++-------------- 01OS/01OS/server/teach.py | 56 ++++++++++++++++++++++++++++++-- 01OS/01OS/utils/accumulator.py | 6 +++- 4 files changed, 86 insertions(+), 30 deletions(-) diff --git a/01OS/01OS/clients/base_device.py b/01OS/01OS/clients/base_device.py index a8e75d9..7e6945b 100644 --- a/01OS/01OS/clients/base_device.py +++ b/01OS/01OS/clients/base_device.py @@ -329,5 +329,6 @@ class Device: listener.start() def start(self): - asyncio.run(self.start_async()) - p.terminate() \ No newline at end of file + if os.getenv('TEACH_MODE') == "False": + asyncio.run(self.start_async()) + p.terminate() \ No newline at end of file diff --git a/01OS/01OS/server/server.py b/01OS/01OS/server/server.py index 0902a63..553ee27 100644 --- a/01OS/01OS/server/server.py +++ b/01OS/01OS/server/server.py @@ -20,7 +20,7 @@ from .i import configure_interpreter from interpreter import interpreter import ngrok from ..utils.accumulator import Accumulator - +from .teach import teach from .utils.logs import setup_logging from .utils.logs import logger setup_logging() @@ -101,8 +101,6 @@ async def websocket_endpoint(websocket: WebSocket): send_task = asyncio.create_task(send_messages(websocket)) try: await asyncio.gather(receive_task, send_task) - except WebSocketDisconnect: - pass except Exception as e: logger.debug(traceback.format_exc()) logger.info(f"Connection lost. Error: {e}") @@ -290,27 +288,30 @@ from uvicorn import Config, Server if __name__ == "__main__": async def main(): - # Start listening - asyncio.create_task(listener()) - - # Start watching the kernel if it's your job to do that - if os.getenv('CODE_RUNNER') == "server": - asyncio.create_task(put_kernel_messages_into_queue(from_computer)) - - server_url = os.getenv('SERVER_URL') - if not server_url: - raise ValueError("The environment variable SERVER_URL is not set. Please set it to proceed.") - parsed_url = urllib.parse.urlparse(server_url) - - # Set up Ngrok - ngrok_auth_token = os.getenv('NGROK_AUTHTOKEN') - if ngrok_auth_token is not None: - await setup_ngrok(ngrok_auth_token, parsed_url) - - logger.info("Starting `server.py`...") + if os.getenv('TEACH_MODE') == "True": + teach() + else: + # Start listening + asyncio.create_task(listener()) + + # Start watching the kernel if it's your job to do that + if os.getenv('CODE_RUNNER') == "server": + asyncio.create_task(put_kernel_messages_into_queue(from_computer)) + + server_url = os.getenv('SERVER_URL') + if not server_url: + raise ValueError("The environment variable SERVER_URL is not set. Please set it to proceed.") + parsed_url = urllib.parse.urlparse(server_url) + + # Set up Ngrok + ngrok_auth_token = os.getenv('NGROK_AUTHTOKEN') + if ngrok_auth_token is not None: + await setup_ngrok(ngrok_auth_token, parsed_url) + + logger.info("Starting `server.py`...") - config = Config(app, host=parsed_url.hostname, port=parsed_url.port, lifespan='on') - server = Server(config) - await server.serve() + config = Config(app, host=parsed_url.hostname, port=parsed_url.port, lifespan='on') + server = Server(config) + await server.serve() asyncio.run(main()) \ No newline at end of file diff --git a/01OS/01OS/server/teach.py b/01OS/01OS/server/teach.py index 524375f..1da969b 100644 --- a/01OS/01OS/server/teach.py +++ b/01OS/01OS/server/teach.py @@ -1,9 +1,59 @@ from datetime import datetime +from .utils.logs import setup_logging, logger +import tkinter as tk +import tkinter.simpledialog +from interpreter import interpreter +from tkinter import messagebox +from ..utils.accumulator import Accumulator +import time +import os +setup_logging() +accumulator = Accumulator() class Skill: def __init__(self, name: str): self.skill_name = name self.steps = [] - - def teach(self, code: str): - self.steps.append(code) \ No newline at end of file + +def to_camel_case(text): + words = text.split() + camel_case_string = words[0].lower() + ''.join(word.title() for word in words[1:]) + return camel_case_string + +def generate_python_code(function_name, steps): + code_string = f'def {to_camel_case(function_name)}():\n' + code_string += f' """{function_name}"""\n' + code_string += f' print({steps})\n' + return code_string + +def teach(): + root = tk.Tk() + root.withdraw() + + skill_name = tkinter.simpledialog.askstring("Skill Name", "Please enter the name for the skill:") + skill = Skill(skill_name) + while True: + step = tkinter.simpledialog.askstring("Next Step", "Enter the next step (or 'end' to finish): ") + logger.info(f"Performing step: {step}") + if step == "end": + break + + for chunk in interpreter.chat(step, stream=True, display=False): + if "format" in chunk and chunk["format"] == "execution": + content = chunk["content"] + language = content["format"] + code = content["content"] + interpreter.computer.run(code, language) + time.sleep(0.05) + accumulator.accumulate(chunk) + + isCorrect = messagebox.askyesno("To Proceed?", "Did I do this step right?") + if isCorrect: + skill.steps.append(step) + + print(skill.skill_name, skill.steps) + python_code = generate_python_code(skill.skill_name, skill.steps) + SKILLS_DIR = os.path.dirname(__file__) + "/skills" + filename = os.path.join(SKILLS_DIR, f"{skill.skill_name.replace(' ', '_')}.py") + with open(filename, "w") as file: + file.write(python_code) diff --git a/01OS/01OS/utils/accumulator.py b/01OS/01OS/utils/accumulator.py index b6353cd..0129cd3 100644 --- a/01OS/01OS/utils/accumulator.py +++ b/01OS/01OS/utils/accumulator.py @@ -26,7 +26,11 @@ class Accumulator: if "content" not in self.message: self.message["content"] = chunk["content"] else: - self.message["content"] += chunk["content"] + if type(chunk["content"]) == dict: + # dict concatenation cannot happen, so we see if chunk is a dict + self.message["content"]["content"] += chunk["content"]["content"] + else: + self.message["content"] += chunk["content"] return None if "end" in chunk: