meissa/meissa/server.py

155 lines
4.4 KiB
Python

# Meissa - A trainable and simple text to speech server
#
# Copyright (c) 2023 Sameer Rahmani <lxsameer@gnu.org>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 2.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import sys
import asyncio
import msgpack
from meissa import utils, worker
class Server:
    """
    A TTS async server.

    Clients send msgpack-encoded maps of the form
    ``{"command": <name>, "payload": <value>}`` over TCP. Each request is
    dispatched to the ``command_<name>`` coroutine of this class and the
    returned map is sent back msgpack-encoded. Speech jobs are handed to a
    background worker task through an ``asyncio.Queue``.
    """

    def __init__(self, ctx):
        """
        Create the job queue and the two events shared with the worker.

        `ctx` is stored as-is and forwarded to the worker task.
        """
        self.queue = asyncio.Queue()
        self.stop_event = asyncio.Event()
        self.worker_stop = asyncio.Event()
        self.ctx = ctx
        self.worker = None

    def create_worker(self):
        """
        Spawn a new worker task and remember it in `self.worker`.

        Must be called while an event loop is running, because it uses
        `asyncio.create_task`.
        """
        self.worker = asyncio.create_task(
            worker.worker(
                self.ctx,
                self.queue,
                self.worker_stop,
                self.stop_event,
            )
        )

    def stop_worker(self):
        """
        Signal the worker to stop and discard all pending jobs.
        """
        # NOTE(review): both events are set here and never cleared in this
        # class; presumably the worker resets them -- confirm against
        # `worker.worker`.
        self.worker_stop.set()
        self.stop_event.set()
        self.empty_queue()

    def empty_queue(self):
        """
        Drop every job that is still waiting in the queue.
        """
        while not self.queue.empty():
            self.queue.get_nowait()
            # `task_done` is a method of the queue, not of the dequeued job
            # (jobs are plain payload values put by the enqueue commands).
            self.queue.task_done()

    def verify_speech_payload(self, payload):
        """
        Validate the payload of a speech command.

        Returns an error response when `payload` is not a map or has no
        non-empty "text" entry, otherwise `None`.
        """
        # The payload comes straight from the network, so it may be any
        # msgpack type; only maps support the `.get` lookup below.
        if not isinstance(payload, dict) or not payload.get("text"):
            return self.err(
                "'text' field is mandatory for 'enqueue', 'clear_n_enqueue'"
            )
        return None

    async def command_stop(self, _=None):
        """
        Stop the current worker, drop pending jobs and start a new worker.

        The payload argument is ignored; it is accepted because
        `handle_command` passes the payload to every command handler.
        """
        utils.info("Stopping all jobs")
        self.stop_worker()
        self.create_worker()
        return self.ok()

    async def command_clear_n_enqueue(self, job):
        """
        Drop everything that is queued or running, then enqueue `job`.

        Returns an error response if `job` fails validation.
        """
        err = self.verify_speech_payload(job)
        if err:
            return err
        self.stop_worker()
        self.create_worker()
        utils.info(f"Stopping all jobs and starting: {job}")
        self.queue.put_nowait(job)
        return self.ok()

    async def command_enqueue(self, job):
        """
        Append `job` to the queue.

        Returns an error response if `job` fails validation.
        """
        err = self.verify_speech_payload(job)
        if err:
            return err
        self.queue.put_nowait(job)
        utils.info(f"Enqueued job: {job}")
        return self.ok()

    async def command_status(self, _=None):
        """
        Report the number of jobs currently waiting in the queue.
        """
        return self.ok({"queue": self.queue.qsize()})

    def err(self, msg):
        """
        Build an error response carrying `msg`.
        """
        return {"status": "error", "error": msg}

    def ok(self, payload=None):
        """
        Build a success response; `payload` defaults to a fresh empty map.
        """
        # A `{}` default would be one dict shared by every response.
        return {"status": "ok", "payload": {} if payload is None else payload}

    async def handle_command(self, command_pack):
        """
        Dispatch one decoded request to its `command_<name>` handler.

        Returns the handler's response, or an error response when the
        request is not a map, lacks "command"/"payload", or names an
        unknown command.
        """
        utils.log("DEBUG", f"{str(command_pack)}")
        if not isinstance(command_pack, dict):
            return self.err("Request must be a map")

        command = command_pack.get("command")
        if not command:
            return self.err("Not command field")

        payload = command_pack.get("payload")
        if payload is None:
            return self.err("Not payload field")

        command_handler = getattr(self, f"command_{command}", None)
        if command_handler is None:
            return self.err(f"No command '{command}'!")
        return await command_handler(payload)

    async def handle_client(self, reader, writer):
        """
        Serve one client connection until it closes or sends bad data.

        Incoming bytes are fed to a streaming msgpack unpacker, so a request
        may span several reads and one read may carry several requests.
        """
        # NOTE(review): a worker is created for every connection and
        # replaces `self.worker`; confirm that several concurrent workers
        # on the shared queue are intended.
        self.create_worker()
        unpacker = msgpack.Unpacker()
        try:
            while True:
                data = await reader.read(1024)
                if not data:
                    break
                try:
                    unpacker.feed(data)
                    # Yields only complete messages; a partial one stays
                    # buffered until the next read completes it.
                    for command in unpacker:
                        response = await self.handle_command(command)
                        writer.write(msgpack.packb(response))
                        await writer.drain()
                except msgpack.UnpackException as exc:
                    # The stream cannot be resynchronised after bad data:
                    # report the problem and drop the connection.
                    writer.write(
                        msgpack.packb(self.err(f"Malformed request: {exc}"))
                    )
                    await writer.drain()
                    break
        finally:
            utils.info("Disconnecting.")
            writer.close()

    async def run_server(self, host, port):
        """
        Listen on `host`:`port` and serve clients until interrupted.
        """
        server = await asyncio.start_server(self.handle_client, host, port)
        addr = server.sockets[0].getsockname()
        utils.info(f"Server listening on {addr}")
        async with server:
            try:
                await server.serve_forever()
            except KeyboardInterrupt:
                # NOTE(review): depending on how the event loop is run,
                # Ctrl-C may reach this coroutine as a cancellation instead
                # of KeyboardInterrupt -- confirm this branch is reachable.
                utils.info("Shutting Down")
                sys.exit()
async def start(ctx):
    """
    Run the TTS server on the address given in the config file.

    The `host` and `port` entries of the config obtained from `ctx` select
    the listening address; they default to 127.0.0.1 and 6666. The coroutine
    returns only when the server stops.
    """
    bind_host = utils.config(ctx).get("host", "127.0.0.1")
    bind_port = utils.config(ctx).get("port", 6666)
    await Server(ctx).run_server(bind_host, bind_port)