Coverage for kgi / templates.py: 81%
81 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-23 08:53 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-23 08:53 +0000
1# SPDX-FileCopyrightText: 2026 Arcangelo Massari <arcangelo.massari@unibo.it>
2#
3# SPDX-License-Identifier: ISC
5"""Template implementations for different data formats."""
7from __future__ import annotations
9import pandas as pd
10import sqlalchemy
11from sqlalchemy import Column, MetaData, Table
12from sqlalchemy.dialects import postgresql
13from sqlalchemy.schema import CreateTable
14from sqlalchemy.sql.sqltypes import Boolean, Date, DateTime, Integer, Numeric, String
16from .base import Template
class RDBTemplate(Template):
    """Template for relational database format.

    Infers a table schema from a pandas DataFrame, creates the table (or
    alters an existing one to match) and inserts the DataFrame rows.
    The INSERT statement is always built with the PostgreSQL dialect.
    """

    # Substring of ``str(dtype)`` -> SQLAlchemy type, checked in order.
    # "datetime" must come before "date" because the latter is a substring
    # of the former.
    _DTYPE_MARKERS = (
        ("int", Integer),
        ("float", Numeric),
        ("bool", Boolean),
        ("datetime", DateTime),
        ("date", Date),
    )

    # Whole-word SQL keywords that mark a source name as a query.
    _SQL_KEYWORD_PATTERN = r"\b(?:SELECT|FROM|WHERE|JOIN|GROUP\s+BY|ORDER\s+BY)\b"

    def __init__(self, db_url):
        """Store the database URL later passed to ``sqlalchemy.create_engine``."""
        self.db_url = db_url

    def create_engine(self):
        """Create a new SQLAlchemy engine for ``self.db_url``."""
        return sqlalchemy.create_engine(self.db_url)

    def create_template(self) -> str:
        """Return a fixed description; the structure comes from the database schema."""
        return "RDB template: structure will be determined by the database schema"

    def fill_data(self, data: pd.DataFrame, source_name: str) -> str:
        """Write ``data`` to the table ``source_name`` and return the SQL used.

        Args:
            data: Rows to store; the table schema is inferred from it.
            source_name: Target table name. When it looks like a SQL query
                (see ``_is_sql_query``) nothing is executed and only the
                SQL text is returned.

        Returns:
            For an empty DataFrame, the ``CREATE TABLE`` statement (the table
            is created if it does not exist). Otherwise the ``CREATE TABLE``
            and ``INSERT`` statements joined as ``"<create>;<insert>;"``.
        """
        table_name = source_name
        table = self._get_sqla_table(data, table_name)
        engine = self.create_engine()
        try:
            if data.empty:
                # Create only table structure if DataFrame is empty
                with engine.begin() as connection:
                    if not sqlalchemy.inspect(connection).has_table(table_name):
                        table.create(connection)
                return str(CreateTable(table).compile(engine))

            insert_stmt = postgresql.insert(table).values(
                self._build_records(data, table)
            )

            if not self._is_sql_query(table_name):
                with engine.begin() as connection:
                    self._sync_table(connection, table)
                    connection.execute(insert_stmt)

            # Generate full query for logging purposes
            create_table_query = str(CreateTable(table).compile(engine))
            insert_query = str(
                insert_stmt.compile(
                    dialect=postgresql.dialect(),
                    compile_kwargs={"literal_binds": True},
                )
            )
            return f"{create_table_query};{insert_query};"
        finally:
            # Release pooled connections on every exit path, including the
            # empty-DataFrame return and any exception.
            engine.dispose()

    @staticmethod
    def _is_missing(value) -> bool:
        """Return True for None, pd.NA, pd.NaT and float NaN."""
        if value is None or value is pd.NA or value is pd.NaT:
            return True
        # NaN is the only float that differs from itself; checking the type
        # first avoids ambiguous truth values for array-like cells.
        return isinstance(value, float) and value != value

    def _build_records(self, data: pd.DataFrame, table) -> list[dict]:
        """Return ``data`` as row dicts with String columns coerced to ``str``.

        Missing values in String columns become ``None`` (SQL NULL) instead
        of the literal strings ``"nan"``, ``"<NA>"`` or ``"NaT"``.
        """
        string_columns = [
            col.name for col in table.columns if isinstance(col.type, String)
        ]
        records = data.to_dict(orient="records")
        for record in records:
            for name in string_columns:
                value = record[name]
                record[name] = None if self._is_missing(value) else str(value)
        return records

    def _sync_table(self, connection, table) -> None:
        """Create ``table`` or alter the existing table to match its columns.

        Runs on ``connection`` so that inspection and DDL share the caller's
        transaction. For an existing table, missing columns are added, extra
        columns are dropped, and columns whose reflected type is not an
        instance of the wanted type class are altered.
        """
        inspector = sqlalchemy.inspect(connection)
        if not inspector.has_table(table.name):
            # Create table if it doesn't exist
            table.create(connection)
            return

        dialect = connection.dialect
        quote = dialect.identifier_preparer.quote
        quoted_table = quote(table.name)
        existing_by_name = {
            column["name"]: column for column in inspector.get_columns(table.name)
        }
        wanted_names = {col.name for col in table.columns}

        # Add missing columns
        for col in table.columns:
            if col.name not in existing_by_name:
                # Render the type for the target dialect; the generic
                # rendering (e.g. DATETIME) is not valid on every backend.
                column_type = col.type.compile(dialect=dialect)
                connection.execute(
                    sqlalchemy.text(
                        f"ALTER TABLE {quoted_table} "
                        f"ADD COLUMN {quote(col.name)} {column_type}"
                    )
                )

        # Remove extra columns
        for col_name in sorted(existing_by_name.keys() - wanted_names):
            connection.execute(
                sqlalchemy.text(
                    f"ALTER TABLE {quoted_table} DROP COLUMN {quote(col_name)}"
                )
            )

        # Update column types if necessary
        for col in table.columns:
            existing = existing_by_name.get(col.name)
            if existing is None or isinstance(existing["type"], type(col.type)):
                continue
            column_type = col.type.compile(dialect=dialect)
            # No USING clause is emitted, so the database must be able to
            # cast the stored values implicitly.
            connection.execute(
                sqlalchemy.text(
                    f"ALTER TABLE {quoted_table} "
                    f"ALTER COLUMN {quote(col.name)} TYPE {column_type}"
                )
            )

    def _is_sql_query(self, table_name: str) -> bool:
        """Check if table_name contains SQL keywords as whole words.

        Matching whole words (case-insensitively) keeps ordinary table names
        such as ``selected_items`` from being mistaken for queries.
        """
        import re  # local import: the module's import block is left untouched

        return (
            re.search(self._SQL_KEYWORD_PATTERN, table_name, flags=re.IGNORECASE)
            is not None
        )

    def _get_sqla_table(self, df: pd.DataFrame, table_name: str):
        """Create SQLAlchemy table from DataFrame.

        A column holding at least one ``str`` value is mapped to ``String``.
        Otherwise the type is picked from the dtype name via
        ``_DTYPE_MARKERS``, falling back to ``String``.
        """
        metadata = MetaData()
        columns = []

        for column_name, dtype in df.dtypes.items():
            # Inspect the actual values: any string (alone or mixed with
            # numbers) forces a String column.
            holds_strings = any(
                isinstance(value, str) for value in df[column_name].dropna()
            )
            if holds_strings:
                col_type = String()
            else:
                dtype_name = str(dtype)
                type_factory = next(
                    (
                        factory
                        for marker, factory in self._DTYPE_MARKERS
                        if marker in dtype_name
                    ),
                    String,
                )
                col_type = type_factory()

            columns.append(Column(column_name, col_type))  # type: ignore[arg-type]

        return Table(table_name, metadata, *columns)

    @property
    def columns_decoded(self) -> bool:
        """Always ``True`` for this template."""
        return True