aboutsummaryrefslogtreecommitdiffstats
path: root/scratch/semgrep/server.py
blob: becabbb5d26b9da1175354e3e8532b03981cfcce (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/env python3
import argparse
import collections
import hashlib
import json
from http.server import BaseHTTPRequestHandler, HTTPServer

import chromadb


def checksum(string: str):
    """Return a 32-character hex identifier derived from *string*.

    The identifier is the leading 32 hex digits (128 bits) of the SHA-256
    digest of the UTF-8 encoded text, so equal texts map to equal ids.
    """
    digest = hashlib.sha256(string.encode("utf-8")).hexdigest()
    return digest[:32]


def ensure_list(data):
    """Normalise *data* to a list of strings.

    A single string is wrapped in a new one-element list; a list whose items
    are all strings (including the empty list) is returned as-is, not copied.

    Raises:
        ValueError: for anything else (non-list values, tuples, or lists
            containing a non-string item).
    """
    if isinstance(data, str):
        return [data]
    is_string_list = isinstance(data, list) and all(
        isinstance(item, str) for item in data
    )
    if not is_string_list:
        raise ValueError("Data must be a list of strings")
    return data


def delete_nodes(nodes):
    """Delete every document whose ``"node-id"`` metadata matches one of *nodes*.

    Operates on the module-level ``collection`` global and issues one
    ``collection.delete`` call per node id, in iteration order.
    """
    for node_id in nodes:
        where_clause = {"node-id": node_id}
        collection.delete(where=where_clause)


class MyRequestHandler(BaseHTTPRequestHandler):
    """Expose the module-level Chroma ``collection`` through a JSON POST API.

    The request body must be a JSON object.  The first truthy key among
    ``"query"``, ``"delete"`` and ``"insert"`` (checked in that order)
    selects the operation:

    * ``query``  -- a string or list of strings passed to
      ``collection.query``; the query result is returned.
    * ``delete`` -- a node-id string or list of them; every document whose
      ``"node-id"`` metadata matches is deleted.
    * ``insert`` -- a list of ``{"document": str, "metadata": dict}``
      objects; duplicate texts are dropped, existing documents for each
      ``"node-id"`` found in the metadata are deleted, then the new
      documents are added with ids derived from their text.

    Successful requests are answered with status 200 and a JSON-encoded
    body.  Malformed requests (bad Content-Length, unparsable body,
    non-object payload, unknown key, wrongly-typed values) get status 400
    with a JSON-encoded error string instead of an aborted connection.
    """

    def do_POST(self):
        """Handle one POST request: parse, dispatch, and write the response."""
        try:
            data = self._read_json()
        except ValueError as err:
            # Covers a bad Content-Length, undecodable UTF-8
            # (UnicodeDecodeError) and malformed JSON (JSONDecodeError).
            self._send(400, f"Invalid JSON data: {err}")
            return
        # Pass the payload as an argument: log_message %-formats its first
        # parameter, so a literal '%' in client data must not end up there.
        self.log_message("Received POST request with data: '%s'", data)

        try:
            response = self._dispatch(data)
        except (ValueError, KeyError, TypeError) as err:
            self.log_message("Bad request: %s", err)
            self._send(400, str(err))
            return
        self._send(200, response)

    def _read_json(self):
        """Read exactly Content-Length bytes from the body and decode them as JSON."""
        length = int(self.headers.get("Content-Length", 0))
        if length < 0:
            # rfile.read(-1) would block until the client closes the socket.
            raise ValueError(f"Invalid Content-Length: {length}")
        return json.loads(self.rfile.read(length).decode("utf-8"))

    def _dispatch(self, data):
        """Run the operation selected by *data*; return a JSON-serialisable result."""
        if not isinstance(data, dict):
            raise ValueError(f"Expected a JSON object, got {type(data).__name__}")

        if query := data.get("query"):
            queries = ensure_list(query)
            # Join so list queries can be logged too; newlines flattened for one log line.
            self.log_message(
                "Processing query '%s'", " ".join(queries).replace("\n", " ").strip()
            )
            return collection.query(query_texts=queries)

        if delete_set := data.get("delete"):
            delete_nodes(ensure_list(delete_set))
            return f"Deleted nodes {delete_set}"

        if paragraphs := data.get("insert"):
            documents, metadata = drop_duplicates(paragraphs)
            nodes = {m.get("node-id") for m in metadata}
            self.log_message("Processing metadata %s", nodes)
            # Replace rather than append: drop the nodes' previous documents first.
            delete_nodes(nodes)
            collection.add(
                documents=documents,
                metadatas=metadata,
                ids=[checksum(document) for document in documents],
            )
            return f"Successfully inserted {nodes}"

        raise ValueError(f"Used wrong method. Sent: {data.keys()}")

    def _send(self, status, payload):
        """Write a complete response with *status* and JSON-encoded *payload*."""
        body = json.dumps(payload).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-type", "text/plain")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)


def run_server(port=8080, host=""):
    """Serve ``MyRequestHandler`` on ``(host, port)`` until interrupted.

    Args:
        port: TCP port to listen on.
        host: address to bind; the default ``""`` binds all interfaces.
            NOTE(review): the API performs unauthenticated inserts and
            deletes, so consider ``"127.0.0.1"`` for local-only use.
    """
    # The context manager calls server_close(), releasing the listening
    # socket even when serve_forever() exits via an exception such as
    # KeyboardInterrupt.
    with HTTPServer((host, port), MyRequestHandler) as httpd:
        print(f"Server running on port {port}")
        httpd.serve_forever()


def drop_duplicates(paragraphs):
    """Normalise paragraph texts and keep only the first of each repeated text.

    Each paragraph must be a mapping with ``"document"`` (text) and
    ``"metadata"`` keys.  Newlines in the text become spaces and surrounding
    whitespace is stripped; after normalisation, only the first paragraph
    for each distinct text is kept.  This lets the texts be used directly to
    derive unique ids.

    Returns:
        ``(documents, metadata)`` -- two parallel lists in input order.

    Raises:
        KeyError: if a paragraph lacks ``"document"`` or ``"metadata"``.
    """
    documents = []
    metadata = []
    seen = set()
    # Single pass with a seen-set instead of rescanning the list per
    # duplicate; first-occurrence-wins and order are preserved.
    for paragraph in paragraphs:
        text = paragraph["document"].replace("\n", " ").strip()
        meta = paragraph["metadata"]  # read even for duplicates: missing key still raises
        if text in seen:
            continue
        seen.add(text)
        documents.append(text)
        metadata.append(meta)
    return documents, metadata


def parse_arguments(args=None):
    """Parse the server's command-line options.

    Args:
        args: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``database``, ``database_dir`` and
        ``port`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Run Semantic database server",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    option_table = (
        (
            ("-db", "--database"),
            {"default": "org-roam", "help": "Name of the collection"},
        ),
        (
            ("-D", "--database-dir"),
            {
                "default": "semantic-roam",
                "help": "Directory where to store database files",
            },
        ),
        (
            ("-p", "--port"),
            {"default": 8080, "type": int, "help": "Port where server listens"},
        ),
    )
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)
    namespace = parser.parse_args(args)
    return namespace


if __name__ == "__main__":
    # Parse CLI options, open (or create) the persistent Chroma collection
    # that the request handler and helpers use via the module-level
    # ``collection`` global, then block serving HTTP requests.
    cli_args = parse_arguments()
    client = chromadb.PersistentClient(path=cli_args.database_dir)
    collection = client.get_or_create_collection(cli_args.database)
    run_server(port=cli_args.port)