Coverage for kgi / endpoints.py: 35%
213 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-23 08:53 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-23 08:53 +0000
1# SPDX-FileCopyrightText: 2026 Arcangelo Massari <arcangelo.massari@unibo.it>
2#
3# SPDX-License-Identifier: ISC
5"""SPARQL endpoint implementations."""
7import gzip
8import json
9import logging
10import os
11import re
12import shutil
13import subprocess
14import tempfile
16from io import BytesIO
18from pyoxigraph import BlankNode, DefaultGraph, Literal, NamedNode, Quad, QueryResultsFormat, QuerySolutions, Store
19from sparqlite import SPARQLClient
21from .base import Endpoint
22from .utils import Validator
class RemoteEndpoint(Endpoint):
    """Remote SPARQL endpoint implementation.

    Optionally loads an N-Triples file into a dedicated named graph on the
    endpoint via batched ``INSERT DATA`` updates; the graph is cleared again
    on object destruction.
    """

    def __init__(self, url: str, rdf_file_to_load: str | None = None):
        self._client = SPARQLClient(url)
        self.endpoint_url = url
        self.rdf_file_path = rdf_file_to_load
        self._graph_uri = None
        if rdf_file_to_load:
            self._graph_uri = f"http://temp/graph/{os.path.basename(rdf_file_to_load)}"
            self._load_data()

    def _load_data(self):
        """Clear the target graph, then stream the RDF file into it in batches."""
        assert self.rdf_file_path is not None
        self._client.update(f"CLEAR GRAPH <{self._graph_uri}>")

        batch_limit = 1000
        pending: list[str] = []
        with open(self.rdf_file_path, "r", encoding="utf-8") as handle:
            for raw in handle:
                stripped = raw.strip()
                # Skip blank lines and comment lines.
                if stripped and not stripped.startswith("#"):
                    pending.append(stripped)
                if len(pending) >= batch_limit:
                    self._insert_triples(pending)
                    pending = []
            if pending:
                self._insert_triples(pending)

    def _insert_triples(self, triples):
        """Insert one batch of N-Triples lines with a single INSERT DATA update."""
        parts = [f"INSERT DATA {{\n GRAPH <{self._graph_uri}> {{\n"]
        for entry in triples:
            # Drop a trailing statement terminator; it is re-added uniformly.
            if entry.endswith("."):
                entry = entry[:-1].strip()
            parts.append(f" {entry} .\n")
        parts.append(" }\n}")
        self._client.update("".join(parts))

    def query(self, query: str):
        """Execute a SPARQL query and return JSON string."""
        if self._graph_uri:
            # Scope the query to the loaded graph by wrapping each WHERE body.
            rewritten = query.replace(
                "WHERE {", f"WHERE {{ GRAPH <{self._graph_uri}> {{"
            )
            unclosed = rewritten.count("{") - rewritten.count("}")
            if unclosed > 0:
                rewritten += "}" * unclosed
            query = rewritten
        return json.dumps(self._client.query(query, method="POST"))

    def __repr__(self):
        return f"RemoteEndpoint({self.endpoint_url})"

    def close(self):
        self._client.close()

    def __del__(self):
        """Clean up by removing the graph from the endpoint."""
        graph = getattr(self, "_graph_uri", None)
        if graph:
            try:
                self._client.update(f"CLEAR GRAPH <{graph}>")
            except Exception:
                # Best-effort cleanup; the interpreter may be shutting down.
                pass
        if hasattr(self, "_client"):
            self._client.close()
class VirtuosoEndpoint(RemoteEndpoint):
    """Virtuoso-specific endpoint that uses bulk loading for better performance.

    Instead of batched ``INSERT DATA`` updates, the RDF file is rewritten as
    gzipped N-Quads, copied into Virtuoso's bulk-load directory, and loaded
    through ``isql`` (``ld_dir`` + ``rdf_loader_run``).
    """

    def __init__(
        self,
        url: str,
        rdf_file_to_load: str | None = None,
        container_name: str = "virtuoso-kgi",
    ):
        self.container_name = container_name
        # Directory shared with the Virtuoso server for bulk loading.
        # Mandatory: a missing env var fails fast with KeyError.
        self.host_bulk_load_dir = os.environ["VIRTUOSO_BULK_DIR"]

        # NOTE: RemoteEndpoint.__init__ is intentionally NOT called — it would
        # kick off the slow INSERT-based loader for rdf_file_to_load.
        self._client = SPARQLClient(url)
        self.endpoint_url = url
        self.rdf_file_path = rdf_file_to_load  # (was assigned twice before)
        self._graph_uri = None

        if rdf_file_to_load:
            self._graph_uri = f"http://temp/graph/{os.path.basename(rdf_file_to_load)}"
            self._bulk_load_data()

    def _bulk_load_data(self):
        """Load RDF data using Virtuoso bulk loading instead of INSERT queries.

        Raises whatever ``_execute_sql`` raises if any bulk-load step fails;
        temporary files are always cleaned up.
        """
        assert self.rdf_file_path is not None
        logger = logging.getLogger("kgi")

        self._client.update(f"CLEAR GRAPH <{self._graph_uri}>")

        # Convert N-Triples to N-Quads with target graph
        temp_nq_file = None
        temp_nq_gz_file = None

        try:
            # Create temporary N-Quads file
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".nq", delete=False, encoding="utf-8"
            ) as temp_nq:
                temp_nq_file = temp_nq.name
                with open(self.rdf_file_path, "r", encoding="utf-8") as f:
                    for line in f:
                        line = line.strip()
                        if line and not line.startswith("#"):
                            if line.endswith("."):
                                line = line[:-1].strip()
                            # Add graph URI to make it an N-Quad
                            temp_nq.write(f"{line} <{self._graph_uri}> .\n")

            # Compress the N-Quads file (after the writer is closed/flushed).
            temp_nq_gz_file = temp_nq_file + ".gz"
            with open(temp_nq_file, "rb") as f_in:
                with gzip.open(temp_nq_gz_file, "wb") as f_out:
                    shutil.copyfileobj(f_in, f_out)

            # Copy the gzipped file to the bulk load directory
            bulk_load_file = f"{self.host_bulk_load_dir}/temp_bulk_load.nq.gz"
            shutil.copy2(temp_nq_gz_file, bulk_load_file)

            # Step 1: Clear any existing entries for this file from load_list
            clear_sql = f"DELETE FROM DB.DBA.load_list WHERE ll_file = '{self.host_bulk_load_dir}/temp_bulk_load.nq.gz'"
            try:
                self._execute_sql(clear_sql)
            except Exception as e:
                logger.error(f"Exception running clear command: {e}")
                raise

            # Step 2: Register the file for bulk loading
            register_sql = f"ld_dir('{self.host_bulk_load_dir}', 'temp_bulk_load.nq.gz', 'http://localhost:8890/DAV/ignored')"
            try:
                self._execute_sql(register_sql)
            except Exception as e:
                logger.error(f"Exception running register command: {e}")
                raise

            # Step 3: Run the bulk loader
            try:
                self._execute_sql("rdf_loader_run()")
            except Exception as e:
                logger.error(f"Exception running bulk load: {e}")
                raise

            # Step 4: Verify data was loaded
            count_query = (
                f"SELECT COUNT(*) WHERE {{ GRAPH <{self._graph_uri}> {{ ?s ?p ?o }} }}"
            )
            try:
                result = self._client.query(count_query, method="POST")
                bindings = result["results"]["bindings"]
                loaded = (
                    int(bindings[0][list(bindings[0].keys())[0]]["value"])
                    if bindings
                    else 0
                )
                if loaded == 0:
                    logger.error("WARNING: No triples were loaded into the graph!")
            except Exception as e:
                # Verification is best-effort; never fail the load over it.
                logger.error(f"Could not verify loaded data: {e}")

        finally:
            # Clean up temporary files
            for temp_file in [
                temp_nq_file,
                temp_nq_gz_file,
                f"{self.host_bulk_load_dir}/temp_bulk_load.nq.gz",
            ]:
                if temp_file and os.path.exists(temp_file):
                    try:
                        os.remove(temp_file)
                    except Exception:
                        pass

    def _execute_sql(self, sql_command):
        """Execute a SQL command against Virtuoso via the local isql binary.

        Returns:
            isql's stdout on success.

        Raises:
            RuntimeError: if isql is missing, the command fails, or it times out.
        """
        logger = logging.getLogger("kgi")

        # Use local isql command
        isql_path = "/opt/virtuoso-opensource/bin/isql"
        if not os.path.exists(isql_path):
            logger.error(f"isql not found at {isql_path}")
            raise RuntimeError(f"isql not found at {isql_path}")

        # SECURITY NOTE(review): default dba/dba credentials and raw SQL
        # interpolation into `exec=` — acceptable only for a trusted local
        # container; do not expose this path to untrusted input.
        cmd = [isql_path, "localhost:1111", "dba", "dba", f"exec={sql_command}"]
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
        except subprocess.TimeoutExpired as e:
            logger.error("SQL command timed out")
            raise RuntimeError("SQL command timed out") from e
        except Exception as e:
            logger.error(f"Failed to execute SQL command: {e}")
            raise RuntimeError(f"Failed to execute SQL command: {e}") from e

        # Checked OUTSIDE the try: previously this RuntimeError was caught by
        # the broad `except Exception` above and re-wrapped/double-logged.
        if result.returncode != 0:
            logger.error(f"SQL execution failed: {result.stderr}")
            raise RuntimeError(f"SQL execution failed: {result.stderr}")

        return result.stdout
# Matches a single N-Triples / N-Quads statement:
#   group 1: subject   — IRI (<...>) or blank-node label (_:...)
#   group 2: predicate — IRI
#   group 3: object    — IRI, blank node, or literal with optional
#                        @lang tag and/or ^^<datatype> suffix
#   group 4: optional graph IRI (present only for N-Quads)
_NT_LINE = re.compile(
    r'(<[^>]*>|_:\S+)\s+(<[^>]*>)\s+'
    r'(<[^>]*>|_:\S+|"(?:[^"\\]|\\.)*"(?:@[a-z]+(?:-[a-z0-9]+)*)?(?:\^\^<[^>]*>)?)'
    r'(?:\s+(<[^>]*>))?\s*\.'
)

# Blank-node labels are rewritten as IRIs under this prefix by the parser
# helpers below, so the same label always maps to the same term.
# NOTE(review): presumably done so bnode identity is preserved across store
# operations — confirm against the callers of these helpers.
_BNODE_IRI_PREFIX = "urn:bnode:"
def _parse_term_subject(raw: str) -> NamedNode:
    """Convert a raw N-Triples subject token to a pyoxigraph node.

    ``<iri>`` tokens become a ``NamedNode`` of the IRI; ``_:label`` tokens are
    mapped onto stable IRIs under ``_BNODE_IRI_PREFIX``. Both branches return
    ``NamedNode`` — the previous ``NamedNode | BlankNode`` annotation was
    inaccurate, since a ``BlankNode`` is never constructed here.
    """
    if raw.startswith("<"):
        return NamedNode(raw[1:-1])
    # _:label  ->  <urn:bnode:label>
    return NamedNode(f"{_BNODE_IRI_PREFIX}{raw[2:]}")
def _parse_term_object(raw: str) -> NamedNode | Literal:
    """Convert a raw N-Triples object token to a pyoxigraph term.

    IRIs become ``NamedNode``; blank-node labels are mapped to IRIs under
    ``_BNODE_IRI_PREFIX`` (a ``BlankNode`` is never returned, so the previous
    ``NamedNode | BlankNode | Literal`` annotation was inaccurate). Anything
    else is parsed as a literal with an optional language tag or datatype;
    tokens that do not look like a literal are stored verbatim.

    NOTE(review): backslash escape sequences inside the literal value are
    kept verbatim rather than unescaped — confirm this is intended.
    """
    if raw.startswith("<"):
        return NamedNode(raw[1:-1])
    if raw.startswith("_:"):
        return NamedNode(f"{_BNODE_IRI_PREFIX}{raw[2:]}")
    match = re.match(
        r'^"((?:[^"\\]|\\.)*)"(@([a-z]+(?:-[a-z0-9]+)*))?(\^\^<([^>]*)>)?$', raw
    )
    if not match:
        # Fallback: token didn't look like a quoted literal; keep it as-is.
        return Literal(raw)
    value, _, lang, _, datatype = match.groups()
    if datatype:
        return Literal(value, datatype=NamedNode(datatype))
    if lang:
        return Literal(value, language=lang)
    return Literal(value)
def _parse_ntriples_preserve_bnodes(store: Store, data: str) -> None:
    """Parse N-Triples/N-Quads text into *store*.

    Blank-node labels are converted to stable IRIs by the term helpers, so
    their identity is kept instead of letting the store mint fresh bnodes.
    Blank lines, comment lines, and lines the statement regex rejects are
    silently skipped.
    """
    for raw in data.splitlines():
        stripped = raw.strip()
        if not stripped or stripped.startswith("#"):
            continue
        parsed = _NT_LINE.match(stripped)
        if parsed is None:
            continue
        subj_tok, pred_tok, obj_tok, graph_tok = parsed.groups()
        graph = NamedNode(graph_tok[1:-1]) if graph_tok else DefaultGraph()
        store.add(
            Quad(
                _parse_term_subject(subj_tok),
                NamedNode(pred_tok[1:-1]),
                _parse_term_object(obj_tok),
                graph,
            )
        )
class LocalSparqlGraphStore(Endpoint):
    """Local pyoxigraph-based SPARQL endpoint."""

    def __init__(self, url: str, delete_after_use: bool = False):
        # `url` is a path to a local RDF file, despite the parameter name.
        self.delete_after_use = delete_after_use
        self._store: Store | None = Store()
        with open(url, "r", encoding="utf-8") as handle:
            contents = handle.read()
        _parse_ntriples_preserve_bnodes(self._store, contents)

    def query(self, query: str):
        """Execute a SPARQL query on the local store and return SPARQL JSON."""
        assert self._store is not None
        try:
            solutions = self._store.query(query, use_default_graph_as_union=True)
            # Only SELECT-style results are supported here.
            assert isinstance(solutions, QuerySolutions)
            out = BytesIO()
            solutions.serialize(out, QueryResultsFormat.JSON)
            return out.getvalue().decode()
        except Exception as e:
            logging.getLogger("kgi").error(f"Query execution error: {e}")
            logging.getLogger("kgi").error(f"Failed query: {query}")
            raise

    def __del__(self):
        """Clean up resources."""
        if self.delete_after_use:
            # Drop the in-memory store so its data can be reclaimed.
            self._store = None
class EndpointFactory:
    """Factory for creating SPARQL endpoints."""

    @classmethod
    def create_from_url(cls, url: str):
        """Create an endpoint from a URL or file path.

        Strings that validate as URLs get a remote endpoint; anything else
        is treated as a local RDF file and loaded into an in-process store.
        """
        if not Validator.url(url):
            return LocalSparqlGraphStore(url)
        return RemoteEndpoint(url)