Coverage for kgi / templates.py: 81%

81 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-23 08:53 +0000

1# SPDX-FileCopyrightText: 2026 Arcangelo Massari <arcangelo.massari@unibo.it> 

2# 

3# SPDX-License-Identifier: ISC 

4 

5"""Template implementations for different data formats.""" 

6 

7from __future__ import annotations 

8 

9import pandas as pd 

10import sqlalchemy 

11from sqlalchemy import Column, MetaData, Table 

12from sqlalchemy.dialects import postgresql 

13from sqlalchemy.schema import CreateTable 

14from sqlalchemy.sql.sqltypes import Boolean, Date, DateTime, Integer, Numeric, String 

15 

16from .base import Template 

17 

18 

class RDBTemplate(Template):
    """Template for relational database format.

    Derives a SQLAlchemy table definition from a DataFrame, brings the
    database table in line with that definition and inserts the rows.
    """

    # Ordered (dtype-name fragment, SQLAlchemy type) pairs. "datetime" must be
    # tested before "date" because the former contains the latter.
    _DTYPE_MARKERS = (
        ("int", Integer),
        ("float", Numeric),
        ("bool", Boolean),
        ("datetime", DateTime),
        ("date", Date),
    )

    def __init__(self, db_url):
        """Remember the database URL later handed to ``sqlalchemy.create_engine``."""
        self.db_url = db_url

    def create_engine(self):
        """Create a new SQLAlchemy engine for ``self.db_url``."""
        return sqlalchemy.create_engine(self.db_url)

    def create_template(self) -> str:
        """Return a placeholder string; the RDB structure comes from the schema."""
        return "RDB template: structure will be determined by the database schema"

    def fill_data(self, data: pd.DataFrame, source_name: str) -> str:
        """Fill template with data and create SQL statements.

        Args:
            data: Rows to store; its columns define the table columns.
            source_name: Table name, or a SQL query (in which case nothing is
                executed against the database).

        Returns:
            For an empty DataFrame, the ``CREATE TABLE`` statement only.
            Otherwise ``"<create table>;<insert>;"`` with the insert rendered
            for PostgreSQL using literal values.
        """
        table_name = source_name
        engine = self.create_engine()
        table = self._get_sqla_table(data, table_name)

        # The engine is disposed on every exit path, including the early
        # return for empty data and any exception raised by the database.
        try:
            if data.empty:
                # Create only table structure if DataFrame is empty
                with engine.begin() as connection:
                    if not sqlalchemy.inspect(engine).has_table(table_name):
                        table.create(connection)
                return str(CreateTable(table).compile(engine))

            # Convert data types to match schema before creating insert statement
            data = data.copy()
            for col in table.columns:
                if isinstance(col.type, String):
                    # NOTE(review): only None is preserved here; other missing
                    # markers such as float NaN become the string "nan" --
                    # confirm this is intended.
                    data[col.name] = data[col.name].map(
                        lambda x: str(x) if x is not None else None
                    )

            insert_stmt = postgresql.insert(table).values(
                data.to_dict(orient="records")
            )

            if not self._is_sql_query(table_name):
                with engine.begin() as connection:
                    self._sync_table(connection, engine, table)
                    connection.execute(insert_stmt)

            # Generate full query for logging purposes
            create_table_query = str(CreateTable(table).compile(engine))
            insert_query = str(
                insert_stmt.compile(
                    dialect=postgresql.dialect(),
                    compile_kwargs={"literal_binds": True},
                )
            )
            return f"{create_table_query};{insert_query};"
        finally:
            engine.dispose()

    def _sync_table(self, connection, engine, table: Table) -> None:
        """Create ``table``, or alter the existing table to match its columns.

        For an existing table the statements are issued in this order: add
        missing columns, drop columns absent from ``table``, then change the
        type of columns whose reflected type is not an instance of the
        expected type class.
        """
        inspector = sqlalchemy.inspect(engine)
        if not inspector.has_table(table.name):
            # Create table if it doesn't exist
            table.create(connection)
            return

        existing_by_name = {
            col["name"]: col for col in inspector.get_columns(table.name)
        }
        wanted_names = {col.name for col in table.columns}

        # Quote identifiers and render types with the connection's dialect so
        # embedded quotes are escaped and type names are valid for the
        # target database.
        dialect = connection.dialect
        quote = dialect.identifier_preparer.quote_identifier
        quoted_table = quote(table.name)

        # Add missing columns
        for col in table.columns:
            if col.name not in existing_by_name:
                type_ddl = col.type.compile(dialect=dialect)
                connection.execute(
                    sqlalchemy.text(
                        f"ALTER TABLE {quoted_table} "
                        f"ADD COLUMN {quote(col.name)} {type_ddl}"
                    )
                )

        # Remove extra columns
        for col_name in existing_by_name.keys() - wanted_names:
            connection.execute(
                sqlalchemy.text(
                    f"ALTER TABLE {quoted_table} DROP COLUMN {quote(col_name)}"
                )
            )

        # Update column types if necessary
        for col in table.columns:
            existing_col = existing_by_name.get(col.name)
            if existing_col is None:
                continue
            if isinstance(existing_col["type"], type(col.type)):
                continue
            type_ddl = col.type.compile(dialect=dialect)
            connection.execute(
                sqlalchemy.text(
                    f"ALTER TABLE {quoted_table} "
                    f"ALTER COLUMN {quote(col.name)} TYPE {type_ddl}"
                )
            )

    def _is_sql_query(self, table_name: str) -> bool:
        """Check if table_name contains SQL keywords.

        Keywords are matched case-insensitively and only as whole words, so a
        plain table name that merely embeds a keyword (for example
        ``customers_joined`` or ``from_data``) is not treated as a query.
        """
        import re

        keyword_pattern = r"\b(?:SELECT|FROM|WHERE|JOIN|GROUP\s+BY|ORDER\s+BY)\b"
        return re.search(keyword_pattern, table_name, flags=re.IGNORECASE) is not None

    def _get_sqla_table(self, df: pd.DataFrame, table_name: str):
        """Create SQLAlchemy table from DataFrame.

        A column holding at least one ``str`` value is mapped to ``String``.
        Otherwise the type is chosen from the dtype name (int, float, bool,
        datetime, date), falling back to ``String``.
        """
        metadata = MetaData()
        columns = []

        for column_name, dtype in df.dtypes.items():
            dtype_name = str(dtype)
            # Inspect the actual values: an object column may mix strings with
            # other types, and any string forces a String column.
            has_strings = any(
                isinstance(val, str) for val in df[column_name].dropna()
            )

            if has_strings:
                col_type = String()
            else:
                type_cls = next(
                    (
                        sqla_type
                        for marker, sqla_type in self._DTYPE_MARKERS
                        if marker in dtype_name
                    ),
                    String,
                )
                col_type = type_cls()

            columns.append(Column(column_name, col_type))  # type: ignore[arg-type]

        return Table(table_name, metadata, *columns)

    @property
    def columns_decoded(self) -> bool:
        """Always ``True`` for this template."""
        return True