155 lines
4.4 KiB
Python
155 lines
4.4 KiB
Python
# Meissa - A trainable and simple text to speech server
|
|
#
|
|
# Copyright (c) 2023 Sameer Rahmani <lxsameer@gnu.org>
|
|
#
|
|
# This program is free software; you can redistribute it and/or modify
|
|
# it under the terms of the GNU General Public License as published by
|
|
# the Free Software Foundation, version 2.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU General Public License
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
import sys
|
|
import asyncio
|
|
import msgpack
|
|
|
|
from meissa import utils, worker
|
|
|
|
|
|
class Server:
    """
    A TTS async server.

    Clients send msgpack-encoded maps of the form
    ``{"command": <name>, "payload": <value>}`` over TCP. Each request is
    dispatched to the ``command_<name>`` coroutine of this class and the
    resulting response map -- ``{"status": "ok", "payload": ...}`` or
    ``{"status": "error", "error": ...}`` -- is msgpack-encoded and written
    back to the client.

    Speech jobs are put on ``self.queue`` and consumed by a background task
    created from ``worker.worker``.
    """

    def __init__(self, ctx):
        self.queue = asyncio.Queue()
        # Events handed to the current worker task. `stop_worker` sets them
        # and then replaces them, so each worker generation gets its own pair.
        self.stop_event = asyncio.Event()
        self.worker_stop = asyncio.Event()
        self.ctx = ctx
        # The asyncio.Task running `worker.worker`; None until first started.
        self.worker = None

    def create_worker(self):
        """Spawn a background task running `worker.worker` on the shared queue."""
        self.worker = asyncio.create_task(
            worker.worker(
                self.ctx,
                self.queue,
                self.worker_stop,
                self.stop_event,
            )
        )

    def stop_worker(self):
        """
        Signal the current worker to stop and drop every pending job.

        The current worker's events are set and then replaced with fresh,
        unset events. The old worker keeps its references to the signalled
        events. A worker created afterwards by `create_worker` does not start
        with events that are already set.

        NOTE(review): the old task is not cancelled. If it is blocked in
        `queue.get()` it may still be handed the next job -- confirm how
        `worker.worker` reacts to that.
        """
        self.worker_stop.set()
        self.stop_event.set()
        self.worker_stop = asyncio.Event()
        self.stop_event = asyncio.Event()
        self.empty_queue()

    def empty_queue(self):
        """Discard every job still waiting in the queue."""
        while not self.queue.empty():
            self.queue.get_nowait()
            # `task_done` is a Queue method (jobs are plain decoded values).
            # It keeps the unfinished-task counter balanced for `queue.join()`.
            self.queue.task_done()

    def verify_speech_payload(self, payload):
        """
        Validate the payload of a speech command ('enqueue', 'clear_n_enqueue').

        Returns None when `payload` is a map with a non-empty 'text' field.
        Otherwise returns an error response to send back to the client.
        """
        if not isinstance(payload, dict) or not payload.get("text"):
            return self.err(
                "'text' field is mandatory for 'enqueue', 'clear_n_enqueue'"
            )
        return None

    async def command_stop(self, _=None):
        """
        Stop all jobs and start a fresh worker.

        `handle_command` always passes the request payload. It is ignored
        here; the default keeps zero-argument calls working.
        """
        utils.info("Stopping all jobs")
        self.stop_worker()
        self.create_worker()
        return self.ok()

    async def command_clear_n_enqueue(self, job):
        """Drop all current and pending jobs, then enqueue `job`."""
        err = self.verify_speech_payload(job)
        if err:
            return err

        self.stop_worker()
        self.create_worker()

        utils.info(f"Stopping all jobs and starting: {job}")
        self.queue.put_nowait(job)
        return self.ok()

    async def command_enqueue(self, job):
        """Append `job` to the end of the speech queue."""
        err = self.verify_speech_payload(job)
        if err:
            return err

        self.queue.put_nowait(job)
        utils.info(f"Enqueued job: {job}")
        return self.ok()

    async def command_status(self, _):
        """Report the number of jobs currently waiting in the queue."""
        return self.ok({"queue": self.queue.qsize()})

    def err(self, msg):
        """Build an error response carrying `msg`."""
        return {"status": "error", "error": msg}

    def ok(self, payload=None):
        """Build a success response; `payload` defaults to a fresh empty map."""
        # A literal `{}` default would be a single dict shared by every call.
        return {"status": "ok", "payload": {} if payload is None else payload}

    async def handle_command(self, command_pack):
        """
        Dispatch one decoded request to its ``command_<name>`` coroutine.

        Returns the handler's response map. Returns an error map when the
        request is not a map, lacks 'command' or 'payload', or names an
        unknown command.
        """
        utils.log("DEBUG", str(command_pack))

        if not isinstance(command_pack, dict):
            return self.err("Command must be a map")

        command = command_pack.get("command")
        if not command:
            return self.err("Not command field")

        payload = command_pack.get("payload")
        if payload is None:
            return self.err("Not payload field")

        command_handler = getattr(self, f"command_{command}", None)
        if command_handler is None:
            return self.err(f"No command '{command}'!")

        return await command_handler(payload)

    async def handle_client(self, reader, writer):
        """
        Serve one client connection until it disconnects.

        TCP is a byte stream: one read may hold part of a message or several
        messages. Incoming bytes are therefore fed to a streaming
        `msgpack.Unpacker`, and every complete request it yields is answered
        in order.
        """
        # All connections share one worker; only start it if none is running.
        if self.worker is None or self.worker.done():
            self.create_worker()

        unpacker = msgpack.Unpacker()
        try:
            while True:
                data = await reader.read(1024)
                if not data:
                    break

                unpacker.feed(data)
                try:
                    # Incomplete trailing data stays buffered for the next read.
                    commands = list(unpacker)
                except (ValueError, msgpack.exceptions.UnpackException):
                    # A corrupt stream cannot be resynchronised; drop the client.
                    writer.write(msgpack.packb(self.err("Malformed msgpack data")))
                    await writer.drain()
                    break

                for command in commands:
                    response = await self.handle_command(command)
                    writer.write(msgpack.packb(response))
                    await writer.drain()
        finally:
            utils.info("Disconnecting.")
            writer.close()

    async def run_server(self, host, port):
        """Listen on `host`:`port` and serve clients until interrupted."""
        server = await asyncio.start_server(self.handle_client, host, port)

        addr = server.sockets[0].getsockname()
        utils.info(f"Server listening on {addr}")

        async with server:
            try:
                await server.serve_forever()
            except asyncio.CancelledError:
                # Python 3.11+ asyncio.run cancels the main task on Ctrl-C.
                utils.info("Shutting Down")
                self.stop_worker()
                raise
            except KeyboardInterrupt:
                utils.info("Shutting Down")
                self.stop_worker()
                sys.exit()
|
|
|
|
|
|
async def start(ctx):
    """
    Create a `Server` for `ctx` and serve TCP clients until shutdown.

    The listening address comes from the config returned by
    `utils.config(ctx)`: the `host` key (default "127.0.0.1") and the
    `port` key (default 6666).
    """
    # Read the config once instead of once per key.
    config = utils.config(ctx)
    host = config.get("host", "127.0.0.1")
    port = config.get("port", 6666)

    server = Server(ctx)
    await server.run_server(host, port)
|