Fixed project structure and added basic pytest

pull/72/head
Tahir Siddique 1 year ago
parent a9c0759d0c
commit 2e5c3f59f1

@@ -0,0 +1,10 @@
import os
import sys
import pytest
from fastapi.testclient import TestClient
from .server import app
@pytest.fixture
def client():
return TestClient(app)

@@ -0,0 +1,13 @@
# test_main.py
import subprocess
import uuid
import pytest
from fastapi.testclient import TestClient
def test_ping(client):
response = client.get("/ping")
assert response.status_code == 200
assert response.text == "pong"

@@ -0,0 +1,9 @@
; Config for the pytest runner.
; Suppress DeprecationWarning and UserWarning so they don't spam the interface, but check them periodically.
[pytest]
python_files = tests.py test_*.py
filterwarnings =
ignore::UserWarning
ignore::DeprecationWarning
log_cli = true
log_cli_level = INFO

@@ -1,7 +0,0 @@
def hello_world():
return "Hello, World!"
# A test function to assert that hello_world() returns the expected string
def test_hello_world():
assert hello_world() == "Hello, World!"

@@ -0,0 +1,4 @@
_archive
__pycache__
.idea

@@ -0,0 +1,9 @@
The open-source language model computer.
```bash
pip install _01OS
```
```bash
01 # Runs the 01 server and client
```

@@ -0,0 +1,351 @@
from dotenv import load_dotenv
load_dotenv() # take environment variables from .env.
import os
import asyncio
import threading
import pyaudio
from starlette.websockets import WebSocket
from pynput import keyboard
import json
import traceback
import websockets
import queue
from pydub import AudioSegment
from pydub.playback import play
import io
import time
import wave
import tempfile
from datetime import datetime
import cv2
import base64
from interpreter import interpreter # Just for code execution. Maybe we should let people do from interpreter.computer import run?
# In the future, I guess kernel watching code should be elsewhere? Somewhere server / client agnostic?
from ..server.utils.kernel import put_kernel_messages_into_queue
from ..server.utils.get_system_info import get_system_info
from ..server.utils.process_utils import kill_process_tree
from ..server.utils.logs import setup_logging
from ..server.utils.logs import logger
setup_logging()
os.environ["STT_RUNNER"] = "server"
os.environ["TTS_RUNNER"] = "server"
from ..utils.accumulator import Accumulator
accumulator = Accumulator()
# Configuration for Audio Recording
CHUNK = 1024 # Record in chunks of 1024 samples
FORMAT = pyaudio.paInt16 # 16 bits per sample
CHANNELS = 1 # Mono
RATE = 44100 # Sample rate
RECORDING = False # Flag to control recording state
SPACEBAR_PRESSED = False # Flag to track spacebar press state
# Camera configuration
CAMERA_ENABLED = os.getenv('CAMERA_ENABLED', False)
if isinstance(CAMERA_ENABLED, str):
CAMERA_ENABLED = (CAMERA_ENABLED.lower() == "true")
CAMERA_DEVICE_INDEX = int(os.getenv('CAMERA_DEVICE_INDEX', 0))
CAMERA_WARMUP_SECONDS = float(os.getenv('CAMERA_WARMUP_SECONDS', 0))
# Specify OS
current_platform = get_system_info()
# Initialize PyAudio
p = pyaudio.PyAudio()
send_queue = queue.Queue()
class Device:
def __init__(self):
self.pressed_keys = set()
self.captured_images = []
self.audiosegments = []
self.server_url = ""
def fetch_image_from_camera(self, camera_index=CAMERA_DEVICE_INDEX):
"""Captures an image from the specified camera device and saves it to a temporary file. Adds the image to the captured_images list."""
image_path = None
cap = cv2.VideoCapture(camera_index)
ret, frame = cap.read() # Capture a single frame to initialize the camera
if CAMERA_WARMUP_SECONDS > 0:
# Allow camera to warm up, then snap a picture again
# This is a workaround for some cameras that don't return a properly exposed
# picture immediately when they are first turned on
time.sleep(CAMERA_WARMUP_SECONDS)
ret, frame = cap.read()
if ret:
temp_dir = tempfile.gettempdir()
image_path = os.path.join(temp_dir, f"01_photo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.png")
self.captured_images.append(image_path)
cv2.imwrite(image_path, frame)
logger.info(f"Camera image captured to {image_path}")
logger.info(f"You now have {len(self.captured_images)} images which will be sent along with your next audio message.")
else:
logger.error(f"Error: Couldn't capture an image from camera ({camera_index})")
cap.release()
return image_path
def encode_image_to_base64(self, image_path):
"""Encodes an image file to a base64 string."""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
def add_image_to_send_queue(self, image_path):
"""Encodes an image and adds an LMC message to the send queue with the image data."""
base64_image = self.encode_image_to_base64(image_path)
image_message = {
"role": "user",
"type": "image",
"format": "base64.png",
"content": base64_image
}
send_queue.put(image_message)
# Delete the image file from the file system after sending it
os.remove(image_path)
def queue_all_captured_images(self):
"""Queues all captured images to be sent."""
for image_path in self.captured_images:
self.add_image_to_send_queue(image_path)
self.captured_images.clear() # Clear the list after sending
async def play_audiosegments(self):
"""Plays them sequentially."""
while True:
try:
while self.audiosegments:
audio = self.audiosegments.pop(0) # pop instead of removing items mid-iteration
play(audio)
await asyncio.sleep(0.1)
except asyncio.exceptions.CancelledError:
# This happens once at the start?
pass
except Exception:
logger.info(traceback.format_exc())
def record_audio(self):
"""Record audio from the microphone and add it to the queue."""
if os.getenv('STT_RUNNER') == "server":
# STT will happen on the server. we're sending audio.
send_queue.put({"role": "user", "type": "audio", "format": "bytes.wav", "start": True})
elif os.getenv('STT_RUNNER') == "client":
# STT will happen here, on the client. we're sending text.
send_queue.put({"role": "user", "type": "message", "start": True})
else:
raise Exception("STT_RUNNER must be set to either 'client' or 'server'.")
"""Record audio from the microphone and add it to the queue."""
stream = p.open(format=FORMAT, channels=CHANNELS, rate=RATE, input=True, frames_per_buffer=CHUNK)
logger.info("Recording started...")
global RECORDING
# Create a temporary WAV file to store the audio data
temp_dir = tempfile.gettempdir()
wav_path = os.path.join(temp_dir, f"audio_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav")
wav_file = wave.open(wav_path, 'wb')
wav_file.setnchannels(CHANNELS)
wav_file.setsampwidth(p.get_sample_size(FORMAT))
wav_file.setframerate(RATE)
while RECORDING:
data = stream.read(CHUNK, exception_on_overflow=False)
wav_file.writeframes(data)
duration = wav_file.getnframes() / RATE # compute before closing the file
wav_file.close()
stream.stop_stream()
stream.close()
logger.info("Recording stopped.")
if duration < 0.3:
# Just pressed it. Send stop message
if os.getenv('STT_RUNNER') == "client":
send_queue.put({"role": "user", "type": "message", "content": "stop"})
send_queue.put({"role": "user", "type": "message", "end": True})
else:
send_queue.put({"role": "user", "type": "audio", "format": "bytes.wav", "content": ""})
send_queue.put({"role": "user", "type": "audio", "format": "bytes.wav", "end": True})
else:
self.queue_all_captured_images()
if os.getenv('STT_RUNNER') == "client":
# NOTE: This branch is currently broken. We moved to the stt_service / llm_service
# architecture, and stt_wav no longer exists here, so client-side STT needs rework.
# Run STT, then send the text
text = stt_wav(wav_path)
logger.debug(f"STT result: {text}")
send_queue.put({"role": "user", "type": "message", "content": text})
send_queue.put({"role": "user", "type": "message", "end": True})
else:
# Stream audio
with open(wav_path, 'rb') as audio_file:
byte_data = audio_file.read(CHUNK)
while byte_data:
send_queue.put(byte_data)
byte_data = audio_file.read(CHUNK)
send_queue.put({"role": "user", "type": "audio", "format": "bytes.wav", "end": True})
if os.path.exists(wav_path):
os.remove(wav_path)
def toggle_recording(self, state):
"""Toggle the recording state."""
global RECORDING, SPACEBAR_PRESSED
if state and not SPACEBAR_PRESSED:
SPACEBAR_PRESSED = True
if not RECORDING:
RECORDING = True
threading.Thread(target=self.record_audio).start()
elif not state and SPACEBAR_PRESSED:
SPACEBAR_PRESSED = False
RECORDING = False
def on_press(self, key):
"""Detect spacebar press and Ctrl+C combination."""
self.pressed_keys.add(key) # Add the pressed key to the set
if keyboard.Key.space in self.pressed_keys:
self.toggle_recording(True)
elif {keyboard.Key.ctrl, keyboard.KeyCode.from_char('c')} <= self.pressed_keys:
logger.info("Ctrl+C pressed. Exiting...")
kill_process_tree()
os._exit(0)
def on_release(self, key):
"""Detect spacebar release and 'c' key press for camera, and handle key release."""
self.pressed_keys.discard(key) # Remove the released key from the key press tracking set
if key == keyboard.Key.space:
self.toggle_recording(False)
elif CAMERA_ENABLED and key == keyboard.KeyCode.from_char('c'):
self.fetch_image_from_camera()
async def message_sender(self, websocket):
while True:
message = await asyncio.get_event_loop().run_in_executor(None, send_queue.get)
if isinstance(message, bytes):
await websocket.send(message)
else:
await websocket.send(json.dumps(message))
send_queue.task_done()
await asyncio.sleep(0.01)
async def websocket_communication(self, WS_URL):
while True:
try:
async with websockets.connect(WS_URL) as websocket:
if CAMERA_ENABLED:
logger.info("Press the spacebar to start/stop recording. Press 'c' to capture an image from the camera. Press CTRL-C to exit.")
else:
logger.info("Press the spacebar to start/stop recording. Press CTRL-C to exit.")
asyncio.create_task(self.message_sender(websocket))
while True:
await asyncio.sleep(0.01)
chunk = await websocket.recv()
logger.debug(f"Got this message from the server: {type(chunk)} {chunk}")
if isinstance(chunk, str):
chunk = json.loads(chunk)
message = accumulator.accumulate(chunk)
if message is None:
# Will be None until we have a full message ready
continue
# At this point, we have our message
if message["type"] == "audio" and message["format"].startswith("bytes"):
# Convert bytes to audio file
audio_bytes = message["content"]
# Create an AudioSegment instance with the raw data
audio = AudioSegment(
# raw audio data (bytes)
data=audio_bytes,
# signed 16-bit little-endian format
sample_width=2,
# 16,000 Hz frame rate
frame_rate=16000,
# mono sound
channels=1
)
self.audiosegments.append(audio)
# Run the code if that's the client's job
if os.getenv('CODE_RUNNER') == "client":
if message["type"] == "code" and "end" in message:
language = message["format"]
code = message["content"]
result = interpreter.computer.run(language, code)
send_queue.put(result)
except Exception:
logger.debug(traceback.format_exc())
logger.info(f"Connecting to `{WS_URL}`...")
await asyncio.sleep(2)
async def start_async(self):
# Configuration for WebSocket
WS_URL = f"ws://{self.server_url}"
# Start the WebSocket communication
asyncio.create_task(self.websocket_communication(WS_URL))
# Start watching the kernel if it's your job to do that
if os.getenv('CODE_RUNNER') == "client":
asyncio.create_task(put_kernel_messages_into_queue(send_queue))
asyncio.create_task(self.play_audiosegments())
# If Raspberry Pi, add the button listener, otherwise use the spacebar
if current_platform.startswith("raspberry-pi"):
logger.info("Raspberry Pi detected, using button on GPIO pin 15")
# Use GPIO pin 15
pindef = ["gpiochip4", "15"] # gpiofind PIN15
print("PINDEF", pindef)
# HACK: needs passwordless sudo
process = await asyncio.create_subprocess_exec("sudo", "gpiomon", "-brf", *pindef, stdout=asyncio.subprocess.PIPE)
while True:
line = await process.stdout.readline()
if line:
line = line.decode().strip()
if "FALLING" in line:
self.toggle_recording(False)
elif "RISING" in line:
self.toggle_recording(True)
else:
break
else:
# Keyboard listener for spacebar press/release
listener = keyboard.Listener(on_press=self.on_press, on_release=self.on_release)
listener.start()
def start(self):
if os.getenv('TEACH_MODE') != "True":
asyncio.run(self.start_async())
p.terminate()
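The client above reassembles the server's streamed LMC messages with `Accumulator`: a `{"start": True}` flag opens a message, content chunks follow (raw bytes for audio, `"content"` deltas for text), and an `{"end": True}` flag completes it, with `accumulate()` returning `None` until then. A minimal sketch of that contract (an illustration under those assumptions, not the actual `Accumulator` source):

```python
# Hypothetical sketch of the accumulate() contract used above, not the real class.
class SimpleAccumulator:
    def __init__(self):
        self.message = None

    def accumulate(self, chunk):
        if isinstance(chunk, bytes):            # raw audio payload
            self.message["content"] += chunk
            return None
        if chunk.get("start"):                  # open a new message
            self.message = {k: v for k, v in chunk.items() if k != "start"}
            is_bytes = chunk.get("format", "").startswith("bytes")
            self.message["content"] = b"" if is_bytes else ""
            return None
        if chunk.get("end"):                    # complete: hand the full message back
            done, self.message = self.message, None
            return done
        if "content" in chunk:                  # text delta
            self.message["content"] += chunk["content"]
        return None
```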

@@ -0,0 +1,11 @@
# ESP32 Playback
To set up audio recording + playback on the ESP32 (M5 Atom), do the following:
1. Open Arduino IDE, and open the `client/client.ino` file
2. Go to Tools -> Board -> Boards Manager, search "esp32", then install the boards by Arduino and Espressif
3. Go to Tools -> Manage Libraries, then install the following:
- M5Atom by M5Stack
- WebSockets by Markus Sattler
4. The board needs to connect to WiFi. After flashing, connect to the ESP32's "captive" WiFi network, which will collect your WiFi details. Once the board is online, it will ask you to enter the _01OS server address in the format "domain.com:port" or "ip:port" (you can verify the address beforehand, as in the sketch after this list). Once it's able to connect, you can use the device.
5. To flash the .ino to the board, connect the board to the USB port, select the port from the dropdown in the IDE, then select the M5Atom board (or M5Stack-ATOM if you have that). Click Upload to flash the board.
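Before entering the address, you can check that the server is reachable; the board performs the same check against the `/ping` endpoint. A quick sketch (the address below is a placeholder for your own):

```python
# Reachability check (hypothetical ip:port; the board hits the same /ping route).
import requests

server_address = "192.168.1.10:8000"  # replace with your _01OS server's ip:port
r = requests.get(f"http://{server_address}/ping", timeout=5)
print(r.status_code, r.text)  # expect: 200 pong
```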

@@ -0,0 +1,675 @@
#include <driver/i2s.h>
#include <M5Atom.h>
#include <Arduino.h> //not needed in the arduino ide
#include <AsyncTCP.h> //https://github.com/me-no-dev/AsyncTCP using the latest dev version from @me-no-dev
#include <DNSServer.h>
#include <ESPAsyncWebServer.h> //https://github.com/me-no-dev/ESPAsyncWebServer using the latest dev version from @me-no-dev
#include <esp_wifi.h> //Used for mpdu_rx_disable android workaround
#include <ArduinoHttpClient.h>
#include <WiFi.h>
#include <WiFiClient.h>
#include <WiFiMulti.h>
#include <WiFiClientSecure.h>
#include <WebSocketsClient.h>
String server_domain = "";
int server_port = 8000;
// ----------------------- START OF WIFI CAPTIVE PORTAL -------------------
// Pre reading on the fundamentals of captive portals https://textslashplain.com/2022/06/24/captive-portals/
const char *ssid = "captive"; // FYI The SSID can't have a space in it.
// const char * password = "12345678"; // At least 8 chars
const char *password = NULL; // no password
#define MAX_CLIENTS 4 // ESP32 supports up to 10 but I have not tested it yet
#define WIFI_CHANNEL 6 // 2.4ghz channel 6 https://en.wikipedia.org/wiki/List_of_WLAN_channels#2.4_GHz_(802.11b/g/n/ax)
const IPAddress localIP(4, 3, 2, 1); // the IP address of the web server; Samsung requires the IP to be in public space
const IPAddress gatewayIP(4, 3, 2, 1); // IP address of the network should be the same as the local IP for captive portals
const IPAddress subnetMask(255, 255, 255, 0); // no need to change: https://avinetworks.com/glossary/subnet-mask/
const String localIPURL = "http://4.3.2.1"; // a string version of the local IP with http, used for redirecting clients to your webpage
// Number of milliseconds to wait without receiving any data before we give up
const int kNetworkTimeout = 30 * 1000;
// Number of milliseconds to wait if no data is available before trying again
const int kNetworkDelay = 1000;
String generateHTMLWithSSIDs()
{
String html = "<!DOCTYPE html><html><body><h2>Select Wi-Fi Network</h2><form action='/submit' method='POST'><label for='ssid'>SSID:</label><select id='ssid' name='ssid'>";
int n = WiFi.scanComplete();
for (int i = 0; i < n; ++i)
{
html += "<option value='" + WiFi.SSID(i) + "'>" + WiFi.SSID(i) + "</option>";
}
html += "</select><br><label for='password'>Password:</label><input type='password' id='password' name='password'><br><input type='submit' value='Connect'></form></body></html>";
return html;
}
const char index_html[] PROGMEM = R"=====(
<!DOCTYPE html>
<html>
<head>
<title>WiFi Setup</title>
<style>
body {background-color:#06cc13;}
h1 {color: white;}
h2 {color: white;}
</style>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
</head>
<body>
<h1>WiFi Setup</h1>
<form action="/submit" method="post">
<label for="ssid">SSID:</label><br>
<input type="text" id="ssid" name="ssid"><br>
<label for="password">Password:</label><br>
<input type="password" id="password" name="password"><br><br>
<input type="submit" value="Connect">
</form>
</body>
</html>
)=====";
const char post_connected_html[] PROGMEM = R"=====(
<!DOCTYPE html>
<html>
<head>
<title>_01OS Setup</title>
<style>
body {background-color:white;}
h1 {color: black;}
h2 {color: black;}
</style>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
</head>
<body>
<h1>_01OS Setup</h1>
<form action="/submit__01OS" method="post">
<label for="server_address">_01OS Server Address:</label><br>
<input type="text" id="server_address" name="server_address"><br>
<input type="submit" value="Connect">
</form>
</body>
</html>
)=====";
DNSServer dnsServer;
AsyncWebServer server(80);
void setUpDNSServer(DNSServer &dnsServer, const IPAddress &localIP)
{
// Define the DNS interval in milliseconds between processing DNS requests
#define DNS_INTERVAL 30
// Set the TTL for DNS response and start the DNS server
dnsServer.setTTL(3600);
dnsServer.start(53, "*", localIP);
}
void startSoftAccessPoint(const char *ssid, const char *password, const IPAddress &localIP, const IPAddress &gatewayIP)
{
// Define the maximum number of clients that can connect to the server
#define MAX_CLIENTS 4
// Define the WiFi channel to be used (channel 6 in this case)
#define WIFI_CHANNEL 6
// Set the WiFi mode to access point and station
// WiFi.mode(WIFI_MODE_AP);
// Define the subnet mask for the WiFi network
const IPAddress subnetMask(255, 255, 255, 0);
// Configure the soft access point with a specific IP and subnet mask
WiFi.softAPConfig(localIP, gatewayIP, subnetMask);
// Start the soft access point with the given ssid, password, channel, max number of clients
WiFi.softAP(ssid, password, WIFI_CHANNEL, 0, MAX_CLIENTS);
// Disable AMPDU RX on the ESP32 WiFi to fix a bug on Android
esp_wifi_stop();
esp_wifi_deinit();
wifi_init_config_t my_config = WIFI_INIT_CONFIG_DEFAULT();
my_config.ampdu_rx_enable = false;
esp_wifi_init(&my_config);
esp_wifi_start();
vTaskDelay(100 / portTICK_PERIOD_MS); // Add a small delay
}
void connectToWifi(String ssid, String password)
{
WiFi.begin(ssid.c_str(), password.c_str());
// Wait for connection to establish
int attempts = 0;
while (WiFi.status() != WL_CONNECTED && attempts < 20)
{
delay(1000);
Serial.print(".");
attempts++;
}
if (WiFi.status() == WL_CONNECTED)
{
Serial.println("Connected to Wi-Fi");
}
else
{
Serial.println("Failed to connect to Wi-Fi. Check credentials.");
}
}
bool connectTo_01OS(String server_address)
{
int err = 0;
String domain;
String portStr;
if (server_address.indexOf(":") != -1) {
domain = server_address.substring(0, server_address.indexOf(':'));
portStr = server_address.substring(server_address.indexOf(':') + 1);
} else {
domain = server_address;
portStr = ""; // or any default value you want to assign
}
int port = 0; // Default port value
if (portStr.length() > 0) {
port = portStr.toInt();
}
WiFiClient c;
HttpClient http(c, domain.c_str(), port);
Serial.println("Connecting to _01OS at " + domain + ":" + port + "/ping");
err = http.get("/ping");
// err = http.get("arduino.cc", "/");
bool connectionSuccess = false;
if (err == 0)
{
Serial.println("Started the ping request");
err = http.responseStatusCode();
if (err >= 0)
{
Serial.print("Got status code: ");
Serial.println(err);
if (err == 200)
{
server_domain = domain;
server_port = port;
connectionSuccess = true;
}
err = http.skipResponseHeaders();
if (err >= 0)
{
int bodyLen = http.contentLength();
Serial.print("Content length is: ");
Serial.println(bodyLen);
Serial.println();
Serial.println("Body returned follows:");
// Now we've got to the body, so we can print it out
unsigned long timeoutStart = millis();
char c;
// Whilst we haven't timed out & haven't reached the end of the body
while ((http.connected() || http.available()) &&
((millis() - timeoutStart) < kNetworkTimeout))
{
if (http.available())
{
c = http.read();
// Print out this character
Serial.print(c);
bodyLen--;
// We read something, reset the timeout counter
timeoutStart = millis();
}
else
{
// We haven't got any data, so let's pause to allow some to
// arrive
delay(kNetworkDelay);
}
}
}
else
{
Serial.print("Failed to skip response headers: ");
Serial.println(err);
}
}
else
{
Serial.print("Getting response failed: ");
Serial.println(err);
}
}
else
{
Serial.print("Connect failed: ");
Serial.println(err);
}
return connectionSuccess;
}
void setUpWebserver(AsyncWebServer &server, const IPAddress &localIP)
{
//======================== Webserver ========================
// WARNING IOS (and maybe macos) WILL NOT POP UP IF IT CONTAINS THE WORD "Success" https://www.esp8266.com/viewtopic.php?f=34&t=4398
// SAFARI (IOS) IS STUPID, G-ZIPPED FILES CAN'T END IN .GZ https://github.com/homieiot/homie-esp8266/issues/476 this is fixed by the webserver serve static function.
// SAFARI (IOS) there is a 128KB limit to the size of the HTML. The HTML can reference external resources/images that bring the total over 128KB
// SAFARI (IOS) popup browser has some severe limitations (javascript disabled, cookies disabled)
// Required
server.on("/connecttest.txt", [](AsyncWebServerRequest *request)
{ request->redirect("http://logout.net"); }); // windows 11 captive portal workaround
server.on("/wpad.dat", [](AsyncWebServerRequest *request)
{ request->send(404); }); // Honestly don't understand what this is but a 404 stops win 10 keep calling this repeatedly and panicking the esp32 :)
// Background responses: Probably not all are Required, but some are. Others might speed things up?
// A Tier (commonly used by modern systems)
server.on("/generate_204", [](AsyncWebServerRequest *request)
{ request->redirect(localIPURL); }); // android captive portal redirect
server.on("/redirect", [](AsyncWebServerRequest *request)
{ request->redirect(localIPURL); }); // microsoft redirect
server.on("/hotspot-detect.html", [](AsyncWebServerRequest *request)
{ request->redirect(localIPURL); }); // apple call home
server.on("/canonical.html", [](AsyncWebServerRequest *request)
{ request->redirect(localIPURL); }); // firefox captive portal call home
server.on("/success.txt", [](AsyncWebServerRequest *request)
{ request->send(200); }); // firefox captive portal call home
server.on("/ncsi.txt", [](AsyncWebServerRequest *request)
{ request->redirect(localIPURL); }); // windows call home
// B Tier (uncommon)
// server.on("/chrome-variations/seed",[](AsyncWebServerRequest *request){request->send(200);}); //chrome captive portal call home
// server.on("/service/update2/json",[](AsyncWebServerRequest *request){request->send(200);}); //firefox?
// server.on("/chat",[](AsyncWebServerRequest *request){request->send(404);}); //No stop asking Whatsapp, there is no internet connection
// server.on("/startpage",[](AsyncWebServerRequest *request){request->redirect(localIPURL);});
// return 404 to webpage icon
server.on("/favicon.ico", [](AsyncWebServerRequest *request)
{ request->send(404); }); // webpage icon
// Serve Basic HTML Page
server.on("/", HTTP_ANY, [](AsyncWebServerRequest *request)
{
String htmlContent = index_html;
Serial.printf("wifi scan complete: %d . WIFI_SCAN_RUNNING: %d", WiFi.scanComplete(), WIFI_SCAN_RUNNING);
if(WiFi.scanComplete() > 0) {
// Scan complete, process results
Serial.println("done scanning wifi");
htmlContent = generateHTMLWithSSIDs();
// WiFi.scanNetworks(true); // Start a new scan in async mode
}
AsyncWebServerResponse *response = request->beginResponse(200, "text/html", htmlContent);
response->addHeader("Cache-Control", "public,max-age=31536000"); // save this file to cache for 1 year (unless you refresh)
request->send(response);
Serial.println("Served Basic HTML Page"); });
// the catch all
server.onNotFound([](AsyncWebServerRequest *request)
{
request->redirect(localIPURL);
Serial.print("onnotfound ");
Serial.print(request->host()); // This gives some insight into whatever was being requested on the serial monitor
Serial.print(" ");
Serial.print(request->url());
Serial.print(" sent redirect to " + localIPURL + "\n"); });
server.on("/submit", HTTP_POST, [](AsyncWebServerRequest *request)
{
String ssid;
String password;
// Check if SSID parameter exists and assign it
if(request->hasParam("ssid", true)) {
ssid = request->getParam("ssid", true)->value();
}
// Check if Password parameter exists and assign it
if(request->hasParam("password", true)) {
password = request->getParam("password", true)->value();
}
// Attempt to connect to the Wi-Fi network with these credentials
connectToWifi(ssid, password);
// Redirect user or send a response back
if (WiFi.status() == WL_CONNECTED) {
String htmlContent = post_connected_html;
AsyncWebServerResponse *response = request->beginResponse(200, "text/html", htmlContent);
response->addHeader("Cache-Control", "public,max-age=31536000"); // save this file to cache for 1 year (unless you refresh)
request->send(response);
Serial.println("Served Post connection HTML Page");
} else {
request->send(200, "text/plain", "Failed to connect to " + ssid);
} });
server.on("/submit__01OS", HTTP_POST, [](AsyncWebServerRequest *request)
{
String server_address;
// Check if the server_address parameter exists and assign it
if(request->hasParam("server_address", true)) {
server_address = request->getParam("server_address", true)->value();
}
// Attempt to reach the _01OS server at this address
bool connectedToServer = connectTo_01OS(server_address);
// Redirect user or send a response back
String connectionMessage;
if (connectedToServer)
{
connectionMessage = "Connected to _01OS " + server_address;
}
else
{
connectionMessage = "Couldn't connect to _01OS " + server_address;
}
request->send(200, "text/plain", connectionMessage); });
}
// ----------------------- END OF WIFI CAPTIVE PORTAL -------------------
// ----------------------- START OF PLAYBACK -------------------
#define CONFIG_I2S_BCK_PIN 19
#define CONFIG_I2S_LRCK_PIN 33
#define CONFIG_I2S_DATA_PIN 22
#define CONFIG_I2S_DATA_IN_PIN 23
#define SPEAKER_I2S_NUMBER I2S_NUM_0
#define MODE_MIC 0
#define MODE_SPK 1
#define DATA_SIZE 1024
uint8_t microphonedata0[1024 * 10];
uint8_t speakerdata0[1024 * 1];
int speaker_offset;
int data_offset;
bool recording = false;
WebSocketsClient webSocket;
class ButtonChecker
{
public:
void loop()
{
lastTickState = thisTickState;
thisTickState = M5.Btn.isPressed() != 0;
}
bool justPressed()
{
return thisTickState && !lastTickState;
}
bool justReleased()
{
return !thisTickState && lastTickState;
}
private:
bool lastTickState = false;
bool thisTickState = false;
};
ButtonChecker button = ButtonChecker();
void InitI2SSpeakerOrMic(int mode)
{
Serial.printf("InitI2sSpeakerOrMic %d\n", mode);
esp_err_t err = ESP_OK;
i2s_driver_uninstall(SPEAKER_I2S_NUMBER);
i2s_config_t i2s_config = {
.mode = (i2s_mode_t)(I2S_MODE_MASTER),
.sample_rate = 16000,
.bits_per_sample =
I2S_BITS_PER_SAMPLE_16BIT, // is fixed at 12bit, stereo, MSB
.channel_format = I2S_CHANNEL_FMT_ALL_RIGHT,
#if ESP_IDF_VERSION > ESP_IDF_VERSION_VAL(4, 1, 0)
.communication_format =
I2S_COMM_FORMAT_STAND_I2S, // Set the format of the communication.
#else // Set the communication format
.communication_format = I2S_COMM_FORMAT_I2S,
#endif
.intr_alloc_flags = ESP_INTR_FLAG_LEVEL1,
.dma_buf_count = 6,
.dma_buf_len = 60,
};
if (mode == MODE_MIC)
{
i2s_config.mode =
(i2s_mode_t)(I2S_MODE_MASTER | I2S_MODE_RX | I2S_MODE_PDM);
}
else
{
i2s_config.mode = (i2s_mode_t)(I2S_MODE_MASTER | I2S_MODE_TX);
i2s_config.use_apll = false;
i2s_config.tx_desc_auto_clear = true;
}
err += i2s_driver_install(SPEAKER_I2S_NUMBER, &i2s_config, 0, NULL);
i2s_pin_config_t tx_pin_config;
#if (ESP_IDF_VERSION > ESP_IDF_VERSION_VAL(4, 3, 0))
tx_pin_config.mck_io_num = I2S_PIN_NO_CHANGE;
#endif
tx_pin_config.bck_io_num = CONFIG_I2S_BCK_PIN;
tx_pin_config.ws_io_num = CONFIG_I2S_LRCK_PIN;
tx_pin_config.data_out_num = CONFIG_I2S_DATA_PIN;
tx_pin_config.data_in_num = CONFIG_I2S_DATA_IN_PIN;
// Serial.println("Init i2s_set_pin");
err += i2s_set_pin(SPEAKER_I2S_NUMBER, &tx_pin_config);
// Serial.println("Init i2s_set_clk");
err += i2s_set_clk(SPEAKER_I2S_NUMBER, 16000, I2S_BITS_PER_SAMPLE_16BIT,
I2S_CHANNEL_MONO);
}
void speaker_play(uint8_t *payload, uint32_t len)
{
Serial.printf("received %lu bytes", len);
size_t bytes_written;
InitI2SSpeakerOrMic(MODE_SPK);
i2s_write(SPEAKER_I2S_NUMBER, payload, len,
&bytes_written, portMAX_DELAY);
}
void webSocketEvent(WStype_t type, uint8_t *payload, size_t length)
{
switch (type)
{
case WStype_DISCONNECTED:
Serial.printf("[WSc] Disconnected!\n");
break;
case WStype_CONNECTED:
Serial.printf("[WSc] Connected to url: %s\n", payload);
// send message to server when Connected
break;
case WStype_TEXT:
Serial.printf("[WSc] get text: %s\n", payload);
{
std::string str(payload, payload + length);
bool isAudio = str.find("\"audio\"") != std::string::npos;
if (isAudio && str.find("\"start\"") != std::string::npos)
{
Serial.println("start playback");
speaker_offset = 0;
InitI2SSpeakerOrMic(MODE_SPK);
}
else if (isAudio && str.find("\"end\"") != std::string::npos)
{
Serial.println("end playback");
// speaker_play(speakerdata0, speaker_offset);
// speaker_offset = 0;
}
}
// send message to server
// webSocket.sendTXT("message here");
break;
case WStype_BIN:
Serial.printf("[WSc] get binary length: %u\n", length);
memcpy(speakerdata0 + speaker_offset, payload, length);
speaker_offset += length;
size_t bytes_written;
i2s_write(SPEAKER_I2S_NUMBER, speakerdata0, speaker_offset, &bytes_written, portMAX_DELAY);
speaker_offset = 0;
// send data to server
// webSocket.sendBIN(payload, length);
break;
case WStype_ERROR:
case WStype_FRAGMENT_TEXT_START:
case WStype_FRAGMENT_BIN_START:
case WStype_FRAGMENT:
case WStype_FRAGMENT_FIN:
break;
}
}
void websocket_setup(String server_domain, int port)
{
if (WiFi.status() != WL_CONNECTED)
{
Serial.println("Not connected to WiFi. Abandoning setup websocket");
return;
}
Serial.println("connected to WiFi");
webSocket.begin(server_domain, port, "/");
webSocket.onEvent(webSocketEvent);
// webSocket.setAuthorization("user", "Password");
webSocket.setReconnectInterval(5000);
}
void flush_microphone()
{
Serial.printf("[microphone] flushing %d bytes of data\n", data_offset);
if (data_offset == 0)
return;
webSocket.sendBIN(microphonedata0, data_offset);
data_offset = 0;
}
// ----------------------- END OF PLAYBACK -------------------
bool hasSetupWebsocket = false;
void setup()
{
// Set the transmit buffer size for the Serial object and start it with a baud rate of 115200.
Serial.setTxBufferSize(1024);
Serial.begin(115200);
// Wait for the Serial object to become available.
while (!Serial)
;
WiFi.mode(WIFI_AP_STA);
// Print a welcome message to the Serial port.
Serial.println("\n\nCaptive Test, V0.5.0 compiled " __DATE__ " " __TIME__ " by CD_FER"); //__DATE__ is provided by the platformio ide
Serial.printf("%s-%d\n\r", ESP.getChipModel(), ESP.getChipRevision());
startSoftAccessPoint(ssid, password, localIP, gatewayIP);
setUpDNSServer(dnsServer, localIP);
WiFi.scanNetworks(true);
setUpWebserver(server, localIP);
server.begin();
Serial.print("\n");
Serial.print("Startup Time:"); // should be somewhere between 270-350 for Generic ESP32 (D0WDQ6 chip, can have a higher startup time on first boot)
Serial.println(millis());
Serial.print("\n");
M5.begin(true, false, true);
M5.dis.drawpix(0, CRGB(128, 128, 0));
}
void loop()
{
dnsServer.processNextRequest(); // I call this at least every 10ms in my other projects (can be higher but I haven't tested it for stability)
delay(DNS_INTERVAL); // seems to help with stability, if you are doing other things in the loop this may not be needed
// Check WiFi connection status
if (WiFi.status() == WL_CONNECTED && !hasSetupWebsocket)
{
if (server_domain != "")
{
Serial.println("Setting up websocket to _01OS " + server_domain + ":" + server_port);
websocket_setup(server_domain, server_port);
InitI2SSpeakerOrMic(MODE_SPK);
hasSetupWebsocket = true;
Serial.println("Websocket connection flow completed");
}
else
{
Serial.println("No valid _01OS server address yet...");
}
// If connected, you might want to do something, like printing the IP address
// Serial.println("Connected to WiFi!");
// Serial.println("IP Address: " + WiFi.localIP().toString());
// Serial.println("SSID " + WiFi.SSID());
}
if (WiFi.status() == WL_CONNECTED && hasSetupWebsocket)
{
button.loop();
if (button.justPressed())
{
Serial.println("Recording...");
webSocket.sendTXT("{\"role\": \"user\", \"type\": \"audio\", \"format\": \"bytes.raw\", \"start\": true}");
InitI2SSpeakerOrMic(MODE_MIC);
recording = true;
data_offset = 0;
Serial.println("Recording ready.");
}
else if (button.justReleased())
{
Serial.println("Stopped recording.");
webSocket.sendTXT("{\"role\": \"user\", \"type\": \"audio\", \"format\": \"bytes.raw\", \"end\": true}");
flush_microphone();
recording = false;
data_offset = 0;
}
else if (recording)
{
Serial.printf("Reading chunk at %d...\n", data_offset);
size_t bytes_read;
i2s_read(
SPEAKER_I2S_NUMBER,
(char *)(microphonedata0 + data_offset),
DATA_SIZE, &bytes_read, (100 / portTICK_RATE_MS));
data_offset += bytes_read;
Serial.printf("Read %d bytes in chunk.\n", bytes_read);
if (data_offset > 1024 * 9)
{
flush_microphone();
}
}
M5.update();
webSocket.loop();
}
}

@@ -0,0 +1,47 @@
#!/usr/bin/env python
"""A basic echo server for testing the device."""
import asyncio
import uuid
import websockets
from websockets.server import serve
import traceback
def divide_chunks(l, n):
# Yield successive n-sized chunks from l
for i in range(0, len(l), n):
yield l[i : i + n]
buffers: dict[uuid.UUID, bytearray] = {}
async def echo(websocket: websockets.WebSocketServerProtocol):
async for message in websocket:
try:
if message == "s":
print("starting stream for", websocket.id)
buffers[websocket.id] = bytearray()
elif message == "e":
print("end, echoing stream for", websocket.id)
await websocket.send("s")
for chunk in divide_chunks(buffers[websocket.id], 1000):
await websocket.send(chunk)
await websocket.send("e")
elif isinstance(message, bytes):
print("recvd", len(message), "bytes from", websocket.id)
buffers[websocket.id].extend(message)
else:
print("ERR: recvd unknown message", message[:10], "from", websocket.id)
except Exception as _e:
traceback.print_exc()
async def main():
async with serve(echo, "0.0.0.0", 9001):
await asyncio.Future() # run forever
asyncio.run(main())
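The protocol is small: the device sends `"s"`, streams raw bytes, then sends `"e"`, and the server echoes the buffered bytes back between its own `"s"`/`"e"` markers. A hypothetical manual test against a locally running instance:

```python
# Hypothetical client for the echo server above (assumes it is running on localhost:9001).
import asyncio
import websockets

async def test_echo():
    async with websockets.connect("ws://localhost:9001") as ws:
        await ws.send("s")                  # start a stream
        await ws.send(b"\x00\x01" * 500)    # stream 1000 bytes
        await ws.send("e")                  # end it; the server echoes the buffer back
        assert await ws.recv() == "s"
        echoed = bytearray()
        while (msg := await ws.recv()) != "e":
            echoed.extend(msg)
        print("echoed", len(echoed), "bytes")

asyncio.run(test_echo())
```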

@@ -0,0 +1,10 @@
from ..base_device import Device
device = Device()
def main(server_url):
device.server_url = server_url
device.start()
if __name__ == "__main__":
import sys
main(sys.argv[1]) # expects the server URL as the first argument, e.g. "localhost:8000"

@@ -0,0 +1,10 @@
from ..base_device import Device
device = Device()
def main(server_url):
device.server_url = server_url
device.start()
if __name__ == "__main__":
import sys
main(sys.argv[1]) # expects the server URL as the first argument, e.g. "localhost:8000"

@@ -0,0 +1,9 @@
from ..base_device import Device
device = Device()
def main():
device.start()
if __name__ == "__main__":
main()

@@ -0,0 +1,10 @@
import os
import sys
import pytest
from fastapi.testclient import TestClient
from .server import app
@pytest.fixture
def client():
return TestClient(app)

@@ -0,0 +1,58 @@
from dotenv import load_dotenv
load_dotenv() # take environment variables from .env.
from platformdirs import user_data_dir
import os
import glob
import json
from pathlib import Path
from interpreter import OpenInterpreter
from .system_messages.BaseSystemMessage import system_message
def configure_interpreter(interpreter: OpenInterpreter):
### SYSTEM MESSAGE
interpreter.system_message = system_message
### LLM SETTINGS
# Local settings
# interpreter.llm.model = "local"
# interpreter.llm.api_base = "http://localhost:8080/v1" # Llamafile default (served over plain HTTP)
# interpreter.llm.max_tokens = 1000
# interpreter.llm.context_window = 3000
# Hosted settings
interpreter.llm.api_key = os.getenv('OPENAI_API_KEY')
interpreter.llm.model = "gpt-4"
### MISC SETTINGS
interpreter.auto_run = True
interpreter.computer.languages = [l for l in interpreter.computer.languages if l.name.lower() in ["applescript", "shell", "zsh", "bash", "python"]]
interpreter.force_task_completion = False
interpreter.offline = True
interpreter.id = 206 # Used to identify itself to other interpreters. This should be changed programmatically so it's unique.
### RESET conversations/user.json
app_dir = user_data_dir('01')
conversations_dir = os.path.join(app_dir, 'conversations')
os.makedirs(conversations_dir, exist_ok=True)
user_json_path = os.path.join(conversations_dir, 'user.json')
with open(user_json_path, 'w') as file:
json.dump([], file)
### SKILLS
skills_dir = user_data_dir('01', 'skills')
interpreter.computer.skills.path = skills_dir
interpreter.computer.skills.import_skills()
interpreter.computer.run("python", "tasks=[]")
interpreter.computer.api_base = "https://oi-video-frame.vercel.app/"
interpreter.computer.run("python","print('test')")
return interpreter

@@ -0,0 +1,28 @@
from dotenv import load_dotenv
load_dotenv() # take environment variables from .env.
import os
import subprocess
from pathlib import Path
### LLM SETUP
# Define the path to a llamafile
llamafile_path = Path(__file__).parent / 'model.llamafile'
# Check if the new llamafile exists, if not download it
if not os.path.exists(llamafile_path):
subprocess.run(
[
"wget",
"-O",
llamafile_path,
"https://huggingface.co/jartine/phi-2-llamafile/resolve/main/phi-2.Q4_K_M.llamafile",
],
check=True,
)
# Make the new llamafile executable
subprocess.run(["chmod", "+x", llamafile_path], check=True)
# Run the new llamafile
subprocess.run([str(llamafile_path)], check=True)

@@ -0,0 +1,453 @@
from dotenv import load_dotenv
load_dotenv() # take environment variables from .env.
from platformdirs import user_data_dir
import ast
import json
import queue
import os
import traceback
from .utils.bytes_to_wav import bytes_to_wav
import re
from fastapi import FastAPI, Request
from fastapi.responses import PlainTextResponse
from starlette.websockets import WebSocket, WebSocketDisconnect
from pathlib import Path
import asyncio
import urllib.parse
from .utils.kernel import put_kernel_messages_into_queue
from .i import configure_interpreter
from interpreter import interpreter
from ..utils.accumulator import Accumulator
from .utils.logs import setup_logging
from .utils.logs import logger
from ..utils.print_markdown import print_markdown
markdown = """
*Starting...*
"""
print("")
print_markdown(markdown)
print("")
setup_logging()
accumulator = Accumulator()
app = FastAPI()
app_dir = user_data_dir('01')
conversation_history_path = os.path.join(app_dir, 'conversations', 'user.json')
SERVER_LOCAL_PORT = int(os.getenv('SERVER_LOCAL_PORT', 8000))
# This is so we only say() full sentences
def is_full_sentence(text):
return text.endswith(('.', '!', '?'))
def split_into_sentences(text):
return re.split(r'(?<=[.!?])\s+', text)
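# e.g. split_into_sentences("One. Two! Three?") -> ["One.", "Two!", "Three?"]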
# Queues
from_computer = queue.Queue() # Just for computer messages from the device. Sync queue because interpreter.run is synchronous
from_user = asyncio.Queue() # Just for user messages from the device.
to_device = asyncio.Queue() # For messages we send.
# Switch code executor to device if that's set
if os.getenv('CODE_RUNNER') == "device":
# (This should probably just loop through all languages and apply these changes instead)
class Python:
# This is the name that will appear to the LLM.
name = "python"
def __init__(self):
self.halt = False
def run(self, code):
"""Generator that yields a dictionary in LMC Format."""
# Prepare the data
message = {"role": "assistant", "type": "code", "format": "python", "content": code}
# Unless it was just sent to the device, send it wrapped in flags
if not (interpreter.messages and interpreter.messages[-1] == message):
to_device.put({"role": "assistant", "type": "code", "format": "python", "start": True})
to_device.put(message)
to_device.put({"role": "assistant", "type": "code", "format": "python", "end": True})
# Stream the response
logger.info("Waiting for the device to respond...")
while True:
chunk = from_computer.get()
logger.info(f"Server received from device: {chunk}")
if "end" in chunk:
break
yield chunk
def stop(self):
self.halt = True
def terminate(self):
"""Terminates the entire process."""
# dramatic!! do nothing
pass
interpreter.computer.languages = [Python]
# Configure interpreter
interpreter = configure_interpreter(interpreter)
@app.get("/ping")
async def ping():
return PlainTextResponse("pong")
@app.websocket("/")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
receive_task = asyncio.create_task(receive_messages(websocket))
send_task = asyncio.create_task(send_messages(websocket))
try:
await asyncio.gather(receive_task, send_task)
except Exception as e:
logger.debug(traceback.format_exc())
logger.info(f"Connection lost. Error: {e}")
@app.post("/")
async def add_computer_message(request: Request):
body = await request.json()
text = body.get("text")
if not text:
return PlainTextResponse("Missing 'text' in request body", status_code=422)
message = {"role": "computer", "type": "console", "format": "output", "content": text}
from_computer.put({"role": "computer", "type": "console", "format": "output", "start": True})
from_computer.put(message)
from_computer.put({"role": "computer", "type": "console", "format": "output", "end": True})
async def receive_messages(websocket: WebSocket):
while True:
try:
try:
data = await websocket.receive()
except Exception as e:
print(str(e))
return
if 'text' in data:
try:
data = json.loads(data['text'])
if data["role"] == "computer":
from_computer.put(data) # To be handled by interpreter.computer.run
elif data["role"] == "user":
await from_user.put(data)
else:
raise("Unknown role:", data)
except json.JSONDecodeError:
pass # data is not JSON, leave it as is
elif 'bytes' in data:
data = data['bytes'] # binary data
await from_user.put(data)
except WebSocketDisconnect as e:
if e.code == 1000:
logger.info("Websocket connection closed normally.")
return
else:
raise
async def send_messages(websocket: WebSocket):
while True:
message = await to_device.get()
logger.debug(f"Sending to the device: {type(message)} {message}")
try:
if isinstance(message, dict):
await websocket.send_json(message)
elif isinstance(message, bytes):
await websocket.send_bytes(message)
else:
raise TypeError("Message must be a dict or bytes")
except:
# Make sure to put the message back in the queue if you failed to send it
await to_device.put(message)
raise
async def listener():
while True:
while True:
if not from_user.empty():
chunk = await from_user.get()
break
elif not from_computer.empty():
chunk = from_computer.get()
break
await asyncio.sleep(1)
message = accumulator.accumulate(chunk)
if message is None:
# Will be None until we have a full message ready
continue
# print(str(message)[:1000])
# At this point, we have our message
if message["type"] == "audio" and message["format"].startswith("bytes"):
if not message["content"]: # If it was nothing / silence
continue
# Convert bytes to audio file
# Format will be bytes.wav or bytes.opus
mime_type = "audio/" + message["format"].split(".")[1]
audio_file_path = bytes_to_wav(message["content"], mime_type)
# For microphone debugging:
if False:
os.system(f"open {audio_file_path}")
import time
time.sleep(15)
text = stt(audio_file_path)
print(text)
message = {"role": "user", "type": "message", "content": text}
# At this point, we have only text messages
# Custom stop message will halt us
if message["content"].lower().strip(".,! ") == "stop":
continue
# Load, append, and save conversation history
with open(conversation_history_path, 'r') as file:
messages = json.load(file)
messages.append(message)
with open(conversation_history_path, 'w') as file:
json.dump(messages, file, indent=4)
accumulated_text = ""
force_task_completion_message = """AUTOMATED MESSAGE: Proceed. You CAN run code on my machine. If you want to run code, start your message with "```"! If the entire task I asked for is done, say exactly 'The task is done.' If you need some specific information (like username or password) say EXACTLY 'Please provide more information.' If it's impossible, say 'The task is impossible.' (If I haven't provided a task, say exactly 'Let me know what you'd like to do next.') Otherwise keep going."""
interpreter.messages = [m for m in interpreter.messages if m["content"] != force_task_completion_message]
insert_force_task_completion_message = True
if any([m["type"] == "image" for m in messages]) and interpreter.llm.model.startswith("gpt-"):
interpreter.llm.model = "gpt-4-vision-preview"
interpreter.llm.supports_vision = True
while insert_force_task_completion_message == True:
for chunk in interpreter.chat(messages, stream=True, display=True):
if chunk["type"] == "code":
insert_force_task_completion_message = False
if any([m["type"] == "image" for m in interpreter.messages]):
interpreter.llm.model = "gpt-4-vision-preview"
logger.debug("Got chunk:", chunk)
# Send it to the user
await to_device.put(chunk)
# Yield to the event loop, so you actually send it out
await asyncio.sleep(0.01)
if os.getenv('TTS_RUNNER') == "server":
# Speak full sentences out loud
if chunk["role"] == "assistant" and "content" in chunk and chunk["type"] == "message":
accumulated_text += chunk["content"]
sentences = split_into_sentences(accumulated_text)
# If we're going to speak, say we're going to stop sending text.
# This should be fixed probably, we should be able to do both in parallel, or only one.
if any(is_full_sentence(sentence) for sentence in sentences):
await to_device.put({"role": "assistant", "type": "message", "end": True})
if is_full_sentence(sentences[-1]):
for sentence in sentences:
await stream_tts_to_device(sentence)
accumulated_text = ""
else:
for sentence in sentences[:-1]:
await stream_tts_to_device(sentence)
accumulated_text = sentences[-1]
# If we're going to speak, say we're going to stop sending text.
# This should be fixed probably, we should be able to do both in parallel, or only one.
if any(is_full_sentence(sentence) for sentence in sentences):
await to_device.put({"role": "assistant", "type": "message", "start": True})
# If we have a new message, save our progress and go back to the top
if not from_user.empty():
# Check if it's just an end flag. We ignore those.
temp_message = await from_user.get()
if type(temp_message) is dict and temp_message.get("role") == "user" and temp_message.get("end"):
# Yup. False alarm.
continue
else:
# Whoops! Put that back
await from_user.put(temp_message)
with open(conversation_history_path, 'w') as file:
json.dump(interpreter.messages, file, indent=4)
# TODO: is triggering seemingly randomly
#logger.info("New user message received. Breaking.")
#break
# Also check if there's any new computer messages
if not from_computer.empty():
with open(conversation_history_path, 'w') as file:
json.dump(interpreter.messages, file, indent=4)
logger.info("New computer message recieved. Breaking.")
break
else:
with open(conversation_history_path, 'w') as file:
json.dump(interpreter.messages, file, indent=4)
force_task_completion_responses = [
"the task is done.",
"the task is impossible.",
"let me know what you'd like to do next.",
"please provide more information.",
]
# Did the LLM respond with one of the key messages?
if (
interpreter.messages
and any(
task_status in interpreter.messages[-1].get("content", "").lower()
for task_status in force_task_completion_responses
)
):
insert_force_task_completion_message = False
break
if insert_force_task_completion_message:
interpreter.messages += [
{
"role": "user",
"type": "message",
"content": force_task_completion_message,
}
]
else:
break
async def stream_tts_to_device(sentence):
force_task_completion_responses = [
"the task is done",
"the task is impossible",
"let me know what you'd like to do next",
]
if sentence.lower().strip().strip(".!?").strip() in force_task_completion_responses:
return
for chunk in stream_tts(sentence):
await to_device.put(chunk)
def stream_tts(sentence):
audio_file = tts(sentence)
with open(audio_file, "rb") as f:
audio_bytes = f.read()
os.remove(audio_file)
file_type = "bytes.raw"
chunk_size = 1024
# Stream the audio
yield {"role": "assistant", "type": "audio", "format": file_type, "start": True}
for i in range(0, len(audio_bytes), chunk_size):
chunk = audio_bytes[i:i+chunk_size]
yield chunk
yield {"role": "assistant", "type": "audio", "format": file_type, "end": True}
from uvicorn import Config, Server
import os
import platform
from importlib import import_module
# these will be overwritten
HOST = ''
PORT = 0
@app.on_event("startup")
async def startup_event():
server_url = f"{HOST}:{PORT}"
print("")
print_markdown(f"\n*Ready.*\n")
print("")
@app.on_event("shutdown")
async def shutdown_event():
print_markdown("*Server is shutting down*")
async def main(server_host, server_port, llm_service, model, llm_supports_vision, llm_supports_functions, context_window, max_tokens, temperature, tts_service, stt_service):
global HOST
global PORT
PORT = server_port
HOST = server_host
# Setup services
application_directory = user_data_dir('01')
services_directory = os.path.join(application_directory, 'services')
service_dict = {'llm': llm_service, 'tts': tts_service, 'stt': stt_service}
for service in service_dict:
service_directory = os.path.join(services_directory, service, service_dict[service])
# This is the folder they can mess around in
config = {"service_directory": service_directory}
if service == "llm":
config.update({
"interpreter": interpreter,
"model": model,
"llm_supports_vision": llm_supports_vision,
"llm_supports_functions": llm_supports_functions,
"context_window": context_window,
"max_tokens": max_tokens,
"temperature": temperature
})
module = import_module(f'.server.services.{service}.{service_dict[service]}.{service}', package='_01OS')
ServiceClass = getattr(module, service.capitalize())
service_instance = ServiceClass(config)
globals()[service] = getattr(service_instance, service)
interpreter.llm.completions = llm
# Start listening
asyncio.create_task(listener())
# Start watching the kernel if it's your job to do that
if True: # in the future, code can run on device. for now, just server.
asyncio.create_task(put_kernel_messages_into_queue(from_computer))
config = Config(app, host=server_host, port=int(server_port), lifespan='on')
server = Server(config)
await server.serve()
# Run the FastAPI app (note: main() takes its arguments from the CLI entry point,
# so running this module directly requires supplying them here)
if __name__ == "__main__":
asyncio.run(main())
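For reference, both HTTP endpoints above can be exercised directly; a hypothetical smoke test, assuming the server is listening on localhost:8000:

```python
# Hypothetical smoke test for the endpoints above (assumes localhost:8000).
import requests

base = "http://localhost:8000"
print(requests.get(f"{base}/ping").text)  # -> pong
# Queues a computer/console message, wrapped in start/end flags server-side:
requests.post(base, json={"text": "hello from the computer"})
```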

@@ -0,0 +1,15 @@
class Llm:
def __init__(self, config):
# Litellm is used by OI by default, so we just modify OI
interpreter = config["interpreter"]
config.pop("interpreter", None)
config.pop("service_directory", None)
for key, value in config.items():
setattr(interpreter, key.replace("-", "_"), value)
self.llm = interpreter.llm.completions

@@ -0,0 +1,49 @@
import os
import subprocess
import requests
import json
class Llm:
def __init__(self, config):
self.install(config["service_directory"])
def install(self, service_directory):
LLM_FOLDER_PATH = service_directory
self.llm_directory = os.path.join(LLM_FOLDER_PATH, 'llm')
if not os.path.isdir(self.llm_directory): # Check if the LLM directory exists
os.makedirs(LLM_FOLDER_PATH, exist_ok=True)
# Install WasmEdge (the pipe requires a shell, so pass a single command string with shell=True)
subprocess.run('curl -sSf https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh | bash -s -- --plugin wasi_nn-ggml', shell=True, check=True)
# Download the Qwen1.5-0.5B-Chat model GGUF file
MODEL_URL = "https://huggingface.co/second-state/Qwen1.5-0.5B-Chat-GGUF/resolve/main/Qwen1.5-0.5B-Chat-Q5_K_M.gguf"
subprocess.run(['curl', '-LO', MODEL_URL], cwd=self.llm_directory)
# Download the llama-api-server.wasm app
APP_URL = "https://github.com/LlamaEdge/LlamaEdge/releases/latest/download/llama-api-server.wasm"
subprocess.run(['curl', '-LO', APP_URL], cwd=self.llm_directory)
# Start the API server in the background (subprocess.run would block here forever)
subprocess.Popen(['wasmedge', '--dir', '.:.', '--nn-preload', 'default:GGML:AUTO:Qwen1.5-0.5B-Chat-Q5_K_M.gguf', 'llama-api-server.wasm', '-p', 'llama-2-chat'], cwd=self.llm_directory)
print("LLM setup completed.")
else:
print("LLM already set up. Skipping download.")
def llm(self, messages):
url = "http://localhost:8080/v1/chat/completions"
headers = {
'accept': 'application/json',
'Content-Type': 'application/json'
}
data = {
"messages": messages,
"model": "llama-2-chat"
}
with requests.post(url, headers=headers, data=json.dumps(data), stream=True) as response:
for line in response.iter_lines():
if line:
yield json.loads(line)
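Usage mirrors the other services: construct the class with a config dict (which installs and launches the server on first run), then iterate the streamed completions. A hedged sketch with a hypothetical directory:

```python
# Hypothetical usage (assumes the llama-api-server is reachable on localhost:8080).
llm = Llm({"service_directory": "/tmp/01-llm-service"})  # installs/launches on first run
for chunk in llm.llm([{"role": "user", "content": "Hello!"}]):
    print(chunk)
```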

@@ -0,0 +1,84 @@
import os
import platform
import subprocess
import time
import wget
import stat
class Llm:
def __init__(self, config):
self.interpreter = config["interpreter"]
config.pop("interpreter", None)
self.install(config["service_directory"])
config.pop("service_directory", None)
for key, value in config.items():
setattr(self.interpreter, key.replace("-", "_"), value)
self.llm = self.interpreter.llm.completions
def install(self, service_directory):
if platform.system() == "Darwin": # Check if the system is macOS
result = subprocess.run(
["xcode-select", "-p"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT
)
if result.returncode != 0:
print(
"Llamafile requires Mac users to have Xcode installed. You can install Xcode from https://developer.apple.com/xcode/ .\n\nAlternatively, you can use `LM Studio`, `Jan.ai`, or `Ollama` to manage local language models. Learn more at https://docs.openinterpreter.com/guides/running-locally ."
)
time.sleep(3)
raise Exception("Xcode is not installed. Please install Xcode and try again.")
# Define the path to the models directory
models_dir = os.path.join(service_directory, "models")
# Check and create the models directory if it doesn't exist
if not os.path.exists(models_dir):
os.makedirs(models_dir)
# Define the path to the new llamafile
llamafile_path = os.path.join(models_dir, "phi-2.Q4_K_M.llamafile")
# Check if the new llamafile exists, if not download it
if not os.path.exists(llamafile_path):
print(
"Attempting to download the `Phi-2` language model. This may take a few minutes."
)
time.sleep(3)
url = "https://huggingface.co/jartine/phi-2-llamafile/resolve/main/phi-2.Q4_K_M.llamafile"
wget.download(url, llamafile_path)
# Make the new llamafile executable
if platform.system() != "Windows":
st = os.stat(llamafile_path)
os.chmod(llamafile_path, st.st_mode | stat.S_IEXEC)
# Run the new llamafile in the background
if os.path.exists(llamafile_path):
# Verify the llamafile is executable before launching it
# (running it with check_call here would block on the server itself)
if not os.access(llamafile_path, os.X_OK):
print("The llamafile is not executable. Please check the file permissions.")
raise PermissionError(f"Not executable: {llamafile_path}")
subprocess.Popen([llamafile_path, "-ngl", "9999"])
else:
error_message = "The llamafile does not exist or is corrupted. Please ensure it has been downloaded correctly or try again."
print(error_message)
self.interpreter.system_message = "You are Open Interpreter, a world-class programmer that can execute code on the user's machine."
self.interpreter.offline = True
self.interpreter.llm.model = "local"
self.interpreter.llm.temperature = 0
self.interpreter.llm.api_base = "http://localhost:8080/v1"
self.interpreter.llm.max_tokens = 1000
self.interpreter.llm.context_window = 3000
self.interpreter.llm.supports_functions = False

@@ -0,0 +1,137 @@
"""
Defines a function which takes a path to an audio file and turns it into text.
"""
from datetime import datetime
import os
import contextlib
import tempfile
import shutil
import ffmpeg
import subprocess
class Stt:
def __init__(self, config):
self.service_directory = config["service_directory"]
install(self.service_directory)
def stt(self, audio_file_path):
return stt(self.service_directory, audio_file_path)
def install(service_dir):
### INSTALL
WHISPER_RUST_PATH = os.path.join(service_dir, "whisper-rust")
script_dir = os.path.dirname(os.path.realpath(__file__))
source_whisper_rust_path = os.path.join(script_dir, "whisper-rust")
if not os.path.exists(source_whisper_rust_path):
print(f"Source directory does not exist: {source_whisper_rust_path}")
exit(1)
if not os.path.exists(WHISPER_RUST_PATH):
shutil.copytree(source_whisper_rust_path, WHISPER_RUST_PATH)
os.chdir(WHISPER_RUST_PATH)
# Check if whisper-rust executable exists before attempting to build
if not os.path.isfile(os.path.join(WHISPER_RUST_PATH, "target/release/whisper-rust")):
# Check if Rust is installed. Needed to build whisper executable
rust_check = subprocess.call('command -v rustc', shell=True)
if rust_check != 0:
print("Rust is not installed or is not in system PATH. Please install Rust before proceeding.")
exit(1)
# Build Whisper Rust executable if not found
subprocess.call('cargo build --release', shell=True)
else:
print("Whisper Rust executable already exists. Skipping build.")
WHISPER_MODEL_PATH = os.path.join(service_dir, "model")
WHISPER_MODEL_NAME = os.getenv('WHISPER_MODEL_NAME', 'ggml-tiny.en.bin')
WHISPER_MODEL_URL = os.getenv('WHISPER_MODEL_URL', 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/')
if not os.path.isfile(os.path.join(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME)):
os.makedirs(WHISPER_MODEL_PATH, exist_ok=True)
subprocess.call(f'curl -L "{WHISPER_MODEL_URL}{WHISPER_MODEL_NAME}" -o "{os.path.join(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME)}"', shell=True)
else:
print("Whisper model already exists. Skipping download.")
def convert_mime_type_to_format(mime_type: str) -> str:
    if mime_type == "audio/x-wav" or mime_type == "audio/wav":
        return "wav"
    if mime_type == "audio/webm":
        return "webm"
    if mime_type == "audio/raw":
        return "dat"
    return mime_type


@contextlib.contextmanager
def export_audio_to_wav_ffmpeg(audio: bytearray, mime_type: str) -> str:
    temp_dir = tempfile.gettempdir()

    # Create a temporary file with the appropriate extension
    input_ext = convert_mime_type_to_format(mime_type)
    input_path = os.path.join(temp_dir, f"input_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.{input_ext}")
    with open(input_path, 'wb') as f:
        f.write(audio)

    # Check if the input file exists
    assert os.path.exists(input_path), f"Input file does not exist: {input_path}"

    # Export to wav
    output_path = os.path.join(temp_dir, f"output_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav")
    print(mime_type, input_path, output_path)
    if mime_type == "audio/raw":
        ffmpeg.input(
            input_path,
            f='s16le',
            ar='16000',
            ac=1,
        ).output(output_path).run()
    else:
        ffmpeg.input(input_path).output(output_path, acodec='pcm_s16le', ac=1, ar='16k').run()

    try:
        yield output_path
    finally:
        os.remove(input_path)
        os.remove(output_path)
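A short usage sketch for the context manager above (illustrative only; `capture.dat` is a hypothetical file of headerless 16 kHz mono s16le samples):

```python
# Sketch: convert raw audio bytes to a temporary wav, then clean up.
with open("capture.dat", "rb") as f:
    raw_audio = bytearray(f.read())

with export_audio_to_wav_ffmpeg(raw_audio, "audio/raw") as wav_path:
    print(f"Temporary wav written to {wav_path}")
# Both the input and output temp files are removed when the with-block exits.
```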
def run_command(command):
    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    return result.stdout, result.stderr


def get_transcription_file(service_directory, wav_file_path: str):
    local_path = os.path.join(service_directory, 'model')
    whisper_rust_path = os.path.join(service_directory, 'whisper-rust', 'target', 'release')
    model_name = os.getenv('WHISPER_MODEL_NAME', 'ggml-tiny.en.bin')

    output, _ = run_command([
        os.path.join(whisper_rust_path, 'whisper-rust'),
        '--model-path', os.path.join(local_path, model_name),
        '--file-path', wav_file_path
    ])

    return output


def stt_wav(service_directory, wav_file_path: str):
    temp_dir = tempfile.gettempdir()
    output_path = os.path.join(temp_dir, f"output_stt_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav")
    ffmpeg.input(wav_file_path).output(output_path, acodec='pcm_s16le', ac=1, ar='16k').run()
    try:
        transcript = get_transcription_file(service_directory, output_path)
    finally:
        os.remove(output_path)
    return transcript


def stt(service_directory, input_data):
    return stt_wav(service_directory, input_data)
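Taken together, the local STT module above can be exercised roughly like this (a sketch; the service directory and wav file name are hypothetical, and the first call triggers the whisper-rust build and model download):

```python
# Sketch: install the whisper-rust service on first use, then transcribe a file.
stt_service = Stt({"service_directory": "/tmp/01-stt"})  # hypothetical location
text = stt_service.stt("sample.wav")                     # hypothetical recording
print(text)
```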

@ -0,0 +1,10 @@
# Generated by Cargo
# will have compiled files and executables
debug/
target/
# These are backup files generated by rustfmt
**/*.rs.bk
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb

File diff suppressed because it is too large.

@ -0,0 +1,14 @@
[package]
name = "whisper-rust"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.79"
clap = { version = "4.4.18", features = ["derive"] }
cpal = "0.15.2"
hound = "3.5.1"
whisper-rs = "0.10.0"
whisper-rs-sys = "0.8.0"

@ -0,0 +1,34 @@
mod transcribe;

use clap::Parser;
use std::path::PathBuf;
use transcribe::transcribe;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// This is the model for Whisper STT
    #[arg(short, long, value_parser, required = true)]
    model_path: PathBuf,

    /// This is the wav audio file that will be converted from speech to text
    #[arg(short, long, value_parser, required = true)]
    file_path: PathBuf,
}

fn main() {
    let args = Args::parse();

    // Both arguments are required, so clap guarantees they are present here
    let result = transcribe(&args.model_path, &args.file_path);

    match result {
        Ok(transcription) => print!("{}", transcription),
        Err(e) => panic!("Error: {}", e),
    }
}
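The Python services above shell out to this binary in the same way `run_command()` does; a minimal sketch of the expected invocation (the paths are illustrative, assuming the binary was built with `cargo build --release`):

```python
# Sketch: call the whisper-rust CLI and read the transcription from stdout.
import subprocess

result = subprocess.run(
    [
        "target/release/whisper-rust",
        "--model-path", "model/ggml-tiny.en.bin",  # hypothetical model location
        "--file-path", "sample.wav",               # hypothetical input file
    ],
    capture_output=True, text=True,
)
print(result.stdout)  # main() prints the transcription to stdout
```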

@ -0,0 +1,64 @@
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
use std::path::PathBuf;

/// Transcribes the given audio file using the whisper-rs library.
///
/// # Arguments
/// * `model_path` - Path to the Whisper model file
/// * `file_path` - Path to the audio file to be transcribed
///
/// # Returns
///
/// A Result containing a String with the transcription if successful, or an error message if not.
pub fn transcribe(model_path: &PathBuf, file_path: &PathBuf) -> Result<String, String> {
    let model_path_str = model_path.to_str().expect("Not valid model path");

    // Load a context and model
    let ctx = WhisperContext::new_with_params(
        model_path_str,
        WhisperContextParameters::default(),
    )
    .map_err(|_| "failed to load model")?;

    // Create a state
    let mut state = ctx.create_state().map_err(|_| "failed to create state")?;

    // Create a params object
    // Note that currently the only implemented strategy is Greedy, BeamSearch is a WIP
    let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });

    // Edit parameters as needed
    params.set_n_threads(1); // Set the number of threads to use
    params.set_translate(true); // Enable translation
    params.set_language(Some("en")); // Set the language to translate to English

    // Disable printing to stdout
    params.set_print_special(false);
    params.set_print_progress(false);
    params.set_print_realtime(false);
    params.set_print_timestamps(false);

    // Load the audio file and reinterpret the raw bytes as native-endian i16 samples
    let audio_data = std::fs::read(file_path)
        .map_err(|e| format!("failed to read audio file: {}", e))?
        .chunks_exact(2)
        .map(|chunk| i16::from_ne_bytes([chunk[0], chunk[1]]))
        .collect::<Vec<i16>>();

    // Convert the audio data to the required format (16 kHz mono f32 samples)
    let audio_data = whisper_rs::convert_integer_to_float_audio(&audio_data);

    // Run the model
    state.full(params, &audio_data[..]).map_err(|_| "failed to run model")?;

    // Fetch the results
    let num_segments = state.full_n_segments().map_err(|_| "failed to get number of segments")?;
    let mut transcription = String::new();
    for i in 0..num_segments {
        let segment = state.full_get_segment_text(i).map_err(|_| "failed to get segment")?;
        transcription.push_str(&segment);
        transcription.push('\n');
    }

    Ok(transcription)
}

@ -0,0 +1,110 @@
from datetime import datetime
import contextlib
import os
import subprocess
import tempfile

import ffmpeg
import openai
from openai import OpenAI

client = OpenAI()


class Stt:
    def __init__(self, config):
        pass

    def stt(self, audio_file_path):
        # `stt` is defined later in this module; it is resolved at call time
        return stt(audio_file_path)
def convert_mime_type_to_format(mime_type: str) -> str:
    if mime_type == "audio/x-wav" or mime_type == "audio/wav":
        return "wav"
    if mime_type == "audio/webm":
        return "webm"
    if mime_type == "audio/raw":
        return "dat"
    return mime_type


@contextlib.contextmanager
def export_audio_to_wav_ffmpeg(audio: bytearray, mime_type: str) -> str:
    temp_dir = tempfile.gettempdir()

    # Create a temporary file with the appropriate extension
    input_ext = convert_mime_type_to_format(mime_type)
    input_path = os.path.join(temp_dir, f"input_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.{input_ext}")
    with open(input_path, 'wb') as f:
        f.write(audio)

    # Check if the input file exists
    assert os.path.exists(input_path), f"Input file does not exist: {input_path}"

    # Export to wav
    output_path = os.path.join(temp_dir, f"output_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav")
    print(mime_type, input_path, output_path)
    if mime_type == "audio/raw":
        ffmpeg.input(
            input_path,
            f='s16le',
            ar='16000',
            ac=1,
        ).output(output_path).run()
    else:
        ffmpeg.input(input_path).output(output_path, acodec='pcm_s16le', ac=1, ar='16k').run()

    try:
        yield output_path
    finally:
        os.remove(input_path)
        os.remove(output_path)
def run_command(command):
    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    return result.stdout, result.stderr


def get_transcription_file(wav_file_path: str):
    local_path = os.path.join(os.path.dirname(__file__), 'local_service')
    whisper_rust_path = os.path.join(os.path.dirname(__file__), 'whisper-rust', 'target', 'release')
    model_name = os.getenv('WHISPER_MODEL_NAME', 'ggml-tiny.en.bin')

    output, _ = run_command([
        os.path.join(whisper_rust_path, 'whisper-rust'),
        '--model-path', os.path.join(local_path, model_name),
        '--file-path', wav_file_path
    ])

    return output


def get_transcription_bytes(audio_bytes: bytearray, mime_type):
    with export_audio_to_wav_ffmpeg(audio_bytes, mime_type) as wav_file_path:
        return get_transcription_file(wav_file_path)


def stt_bytes(audio_bytes: bytearray, mime_type="audio/wav"):
    with export_audio_to_wav_ffmpeg(audio_bytes, mime_type) as wav_file_path:
        return stt_wav(wav_file_path)


def stt_wav(wav_file_path: str):
    # Open in a context manager so the file handle is closed when done
    with open(wav_file_path, "rb") as audio_file:
        try:
            transcript = client.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file,
                response_format="text"
            )
        except openai.BadRequestError as e:
            print(f"openai.BadRequestError: {e}")
            return None

    return transcript


def stt(input_data, mime_type="audio/wav"):
    if isinstance(input_data, str):
        return stt_wav(input_data)
    elif isinstance(input_data, bytearray):
        return stt_bytes(input_data, mime_type)
    else:
        raise ValueError("Input data should be either a path to a wav file (str) or audio bytes (bytearray)")
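A short usage sketch for this hosted variant (assumes `OPENAI_API_KEY` is set in the environment; the file names are hypothetical):

```python
# Sketch: transcribe either a wav file on disk or raw audio bytes.
print(stt("sample.wav"))  # str input -> stt_wav()

with open("sample.webm", "rb") as f:
    print(stt(bytearray(f.read()), "audio/webm"))  # bytes input -> stt_bytes()
```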

Some files were not shown because too many files have changed in this diff.
