1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
|
#!/usr/bin/env python3
import argparse
import collections
import hashlib
import json
from http.server import BaseHTTPRequestHandler, HTTPServer
import chromadb
def checksum(string: str):
    """Return a deterministic 32-hex-character ID for *string*.

    The ID is the leading 32 hex digits (128 bits) of the SHA-256 digest of
    the UTF-8 encoding of *string*, so identical text always yields the same ID.
    """
    digest = hashlib.sha256(string.encode("utf-8")).hexdigest()
    return digest[:32]
def ensure_list(data):
    """Normalize *data* to a list of strings.

    A single string is wrapped in a new one-element list; a list whose items
    are all strings is returned unchanged (the same object). Anything else
    raises ``ValueError``.
    """
    if isinstance(data, str):
        return [data]
    is_string_list = isinstance(data, list) and all(
        isinstance(item, str) for item in data
    )
    if not is_string_list:
        raise ValueError("Data must be a list of strings")
    return data
class MyRequestHandler(BaseHTTPRequestHandler):
    """Serve query/insert POST requests against the module-level ``collection``.

    The request body must be a JSON object containing one of:

    * ``"query"``: a string or list of strings; the value returned by
      ``collection.query(query_texts=...)`` is sent back JSON-encoded.
    * ``"insert"``: a list of ``{"document": ..., "metadata": ...}`` objects.
      Everything already stored under each ``node-id`` found in the metadata
      is deleted, then the de-duplicated paragraphs are added.

    Every request gets exactly one response: 400 for an unreadable body, a
    non-object body or an unknown request key; 500 (with the traceback logged
    to stderr) if processing a well-formed request fails; 200 otherwise.

    NOTE(review): ``collection``, ``ensure_list``, ``drop_duplicates`` and
    ``checksum`` are module-level names; ``collection`` is only bound when the
    file is run as a script.
    """

    def do_POST(self):
        """Handle one POST request and always send exactly one response."""
        try:
            data = self._read_json()
        except ValueError:
            # Bad Content-Length, non-UTF-8 body or invalid JSON
            # (JSONDecodeError and UnicodeDecodeError subclass ValueError).
            self._send_text(400, "Invalid JSON data")
            return
        if not isinstance(data, dict):
            self._send_text(400, "Invalid JSON data: expected a JSON object")
            return

        try:
            if query := data.get("query"):
                response = self._handle_query(query)
            elif paragraphs := data.get("insert"):
                response = self._handle_insert(paragraphs)
            else:
                self._send_text(400, f"Used wrong method. Sent: {data.keys()}")
                return
            response_message = json.dumps(response)
        except Exception as err:
            # Request boundary: reply with an error instead of letting the
            # exception escape, which would drop the connection unanswered.
            import traceback

            self.log_error("Error processing request: %r", err)
            traceback.print_exc()
            self._send_text(500, f"Error processing request: {err}")
            return
        self._send_text(200, response_message)

    def _read_json(self):
        """Read the request body (``Content-Length`` bytes) and parse it as JSON.

        Raises:
            ValueError: the length header is not a non-negative integer, the
                body is not valid UTF-8, or the body is not valid JSON. A
                missing header is treated as an empty body, which is invalid JSON.
        """
        length = int(self.headers.get("Content-Length", 0))
        if length < 0:
            # rfile.read(-1) would block until the client closes the socket.
            raise ValueError(f"Invalid Content-Length: {length}")
        return json.loads(self.rfile.read(length).decode("utf-8"))

    def _handle_query(self, query):
        """Run a similarity query for one or more query strings."""
        queries = ensure_list(query)  # validate before logging/querying
        self.log_message(
            "Processing query '%s'",
            " | ".join(q.replace("\n", " ").strip() for q in queries),
        )
        return collection.query(query_texts=queries)

    def _handle_insert(self, paragraphs):
        """Replace all stored paragraphs of the affected nodes with *paragraphs*."""
        documents, metadata = drop_duplicates(paragraphs)
        nodes = {m.get("node-id") for m in metadata}
        self.log_message("Processing metadata %s", nodes)
        # Replace rather than merge: remove every previously stored paragraph
        # of each node before adding the new set.
        for node in nodes:
            collection.delete(where={"node-id": node})
        collection.add(
            documents=documents,
            metadatas=metadata,
            # Content-derived IDs; drop_duplicates guarantees they are unique here.
            ids=[checksum(document) for document in documents],
        )
        return f"Successfully inserted {nodes}"

    def _send_text(self, status, body):
        """Send a complete UTF-8 ``text/plain`` response with *status* and *body*."""
        payload = body.encode("utf-8")
        self.send_response(status)
        self.send_header("Content-type", "text/plain; charset=utf-8")
        self.send_header("Content-Length", str(len(payload)))
        self.end_headers()
        self.wfile.write(payload)
def run_server(port=8080):
    """Serve ``MyRequestHandler`` on *port* until the process is interrupted.

    The server binds to all interfaces (host ``""``) and handles one request
    at a time. The listening socket is closed when ``serve_forever`` exits,
    including on ``KeyboardInterrupt``.

    NOTE(review): binding to all interfaces exposes the insert/delete endpoint
    to the network; confirm that is intended rather than ``"127.0.0.1"``.
    """
    server_address = ("", port)
    with HTTPServer(server_address, MyRequestHandler) as httpd:
        # flush: when stdout is a pipe it is block-buffered, and
        # serve_forever() never returns, so the line would otherwise never show.
        print(f"Server running on port {port}", flush=True)
        httpd.serve_forever()
def drop_duplicates(paragraphs):
    """Split *paragraphs* into documents and metadata, dropping repeated documents.

    Each document text is normalized (newlines replaced by spaces, surrounding
    whitespace stripped). Only the first paragraph for each normalized text is
    kept, together with its metadata, preserving input order.

    Args:
        paragraphs: iterable of mappings with ``"document"`` (str) and
            ``"metadata"`` keys.

    Returns:
        ``(documents, metadata)``: two parallel lists of equal length.

    Raises:
        KeyError: a paragraph lacks ``"document"`` or ``"metadata"``.
    """
    seen = set()
    documents = []
    metadata = []
    # Single pass with a seen-set: O(n) instead of rescanning the list for
    # every duplicated text.
    for paragraph in paragraphs:
        document = paragraph["document"].replace("\n", " ").strip()
        meta = paragraph["metadata"]  # read for every paragraph, kept or not
        if document in seen:
            continue
        seen.add(document)
        documents.append(document)
        metadata.append(meta)
    return documents, metadata
def parse_arguments(args=None):
    """Parse command-line options for the server.

    Args:
        args: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``database`` (collection name),
        ``database_dir`` (storage directory) and ``port`` (int).
    """
    parser = argparse.ArgumentParser(
        description="Run Semantic database server",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    option = parser.add_argument
    option("-db", "--database", help="Name of the collection", default="org-roam")
    option(
        "-D",
        "--database-dir",
        help="Directory where to store database files",
        default="semantic-roam",
    )
    option("-p", "--port", help="Port where server listens", type=int, default=8080)
    return parser.parse_args(args)
if __name__ == "__main__":
    # Script entry point: parse options, open the database, then serve forever.
    args = parse_arguments()
    # Persistent Chroma client whose files live under --database-dir.
    client = chromadb.PersistentClient(path=args.database_dir)
    # Module-level global read by MyRequestHandler; it must be bound before
    # run_server() starts handling requests (run_server blocks).
    collection = client.get_or_create_collection(args.database)
    run_server(args.port)
|