diff --git a/src/de/steamwar/sql/Field.java b/src/de/steamwar/sql/Field.java index 2180ca8..656fdd9 100644 --- a/src/de/steamwar/sql/Field.java +++ b/src/de/steamwar/sql/Field.java @@ -19,59 +19,16 @@ package de.steamwar.sql; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.HashMap; -import java.util.Map; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; -public class Field { - - private static final Map, String> sqlTypeMapping = new HashMap<>(); - private static final Map, SqlTypeParser> sqlTypeParser = new HashMap<>(); - - //TODO andere Richtung Objekt -> SQL - public static void addTypeMapping(Class clazz, String sqlType, SqlTypeParser parser) { - sqlTypeMapping.put(clazz, sqlType); - sqlTypeParser.put(clazz, parser); - } - - static { - addTypeMapping(String.class, "TEXT", (rs, field) -> rs.getString(field.identifier())); - addTypeMapping(boolean.class, "INTEGER(1)", (rs, field) -> rs.getBoolean(field.identifier())); - addTypeMapping(byte.class, "INTEGER(1)", (rs, field) -> rs.getByte(field.identifier())); - addTypeMapping(short.class, "INTEGER(2)", (rs, field) -> rs.getShort(field.identifier())); - addTypeMapping(int.class, "INTEGER(4)", (rs, field) -> rs.getInt(field.identifier())); - addTypeMapping(long.class, "INTEGER(8)", (rs, field) -> rs.getLong(field.identifier())); - addTypeMapping(float.class, "REAL", (rs, field) -> rs.getFloat(field.identifier())); - addTypeMapping(double.class, "REAL", (rs, field) -> rs.getDouble(field.identifier())); - } - - private final String identifier; - private final Class type; - - private final SqlTypeParser parser; - private final String sqlType; - - public Field(String identifier, Class type) { - this.identifier = identifier; - this.type = type; - this.parser = (SqlTypeParser) sqlTypeParser.get(type); - this.sqlType = sqlTypeMapping.get(type); - } - - public String identifier() { - return identifier; - } - - public Class type() { - return type; - } - - public String sqlType() { - return sqlType; - } - - public T parse(ResultSet rs) throws SQLException { - return parser.parse(rs, this); - } +@Target(ElementType.FIELD) +@Retention(RetentionPolicy.RUNTIME) +public @interface Field { + String[] keys() default {}; + String def() default ""; + boolean nullable() default false; + boolean autoincrement() default false; } diff --git a/src/de/steamwar/sql/Row.java b/src/de/steamwar/sql/Row.java deleted file mode 100644 index 53c6db2..0000000 --- a/src/de/steamwar/sql/Row.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * This file is a part of the SteamWar software. - * - * Copyright (C) 2022 SteamWar.de-Serverteam - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package de.steamwar.sql; - -public class Row { - - private final Table table; - private Object[] values; - - public Row(Table table, Object... values) { - this.table = table; - this.values = values; - } - - private T get(Field field) { - return (T) values[table.getFieldId(field)]; - } - - void update(Field[] fields, Object... values) { - table.update(values[table.keyIds()], fields, values); - } -} diff --git a/src/de/steamwar/sql/SelectStatement.java b/src/de/steamwar/sql/SelectStatement.java new file mode 100644 index 0000000..cc00ecc --- /dev/null +++ b/src/de/steamwar/sql/SelectStatement.java @@ -0,0 +1,72 @@ +/* + * This file is a part of the SteamWar software. + * + * Copyright (C) 2022 SteamWar.de-Serverteam + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package de.steamwar.sql; + +import java.lang.reflect.InvocationTargetException; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +public class SelectStatement extends Statement { + private final Table table; + + SelectStatement(Table table, String... kfields) { + this(table, "SELECT " + Arrays.stream(table.fields).map(f -> f.identifier).collect(Collectors.joining(", ")) + " FROM " + table.name + " WHERE " + Arrays.stream(kfields).map(f -> f + " = ?").collect(Collectors.joining(", "))); + } + + public SelectStatement(Table table, String sql) { + super(sql); + this.table = table; + } + + public T select(Object... values) { + return select(rs -> { + if (rs.next()) + return read(rs); + return null; + }, values); + } + + public List listSelect(Object... values) { + return select(rs -> { + List result = new ArrayList<>(); + while (rs.next()) + result.add(read(rs)); + + return result; + }, values); + } + + private T read(ResultSet rs) throws SQLException { + Object[] params = new Object[table.fields.length]; + for(int i = 0; i < params.length; i++) { + params[i] = table.fields[i].read(rs); + } + + try { + return table.constructor.newInstance(params); + } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { + throw new SecurityException(e); + } + } +} diff --git a/src/de/steamwar/sql/SqlTypeMapper.java b/src/de/steamwar/sql/SqlTypeMapper.java new file mode 100644 index 0000000..68d2ee8 --- /dev/null +++ b/src/de/steamwar/sql/SqlTypeMapper.java @@ -0,0 +1,100 @@ +/* + * This file is a part of the SteamWar software. + * + * Copyright (C) 2022 SteamWar.de-Serverteam + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package de.steamwar.sql; + +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.util.Arrays; +import java.util.IdentityHashMap; +import java.util.Map; + +public final class SqlTypeMapper { + private static final Map, SqlTypeMapper> mappers = new IdentityHashMap<>(); + + public static SqlTypeMapper getMapper(Class clazz) { + return (SqlTypeMapper) mappers.get(clazz); + } + + static { + new SqlTypeMapper<>(String.class, "TEXT", ResultSet::getString, PreparedStatement::setString); + new SqlTypeMapper<>(Boolean.class, "BOOLEAN", ResultSet::getBoolean, PreparedStatement::setBoolean); + new SqlTypeMapper<>(Byte.class, "INTEGER(1)", ResultSet::getByte, PreparedStatement::setByte); + new SqlTypeMapper<>(Short.class, "INTEGER(2)", ResultSet::getShort, PreparedStatement::setShort); + new SqlTypeMapper<>(Integer.class, "INTEGER(4)", ResultSet::getInt, PreparedStatement::setInt); + new SqlTypeMapper<>(Long.class, "INTEGER(8)", ResultSet::getLong, PreparedStatement::setLong); + new SqlTypeMapper<>(Float.class, "REAL", ResultSet::getFloat, PreparedStatement::setFloat); + new SqlTypeMapper<>(Double.class, "REAL", ResultSet::getDouble, PreparedStatement::setDouble); + new SqlTypeMapper<>(Timestamp.class, "TIMESTAMP", ResultSet::getTimestamp, PreparedStatement::setTimestamp); + } + + public static > void ordinalEnumMapper(Class type) { + T[] enumConstants = type.getEnumConstants(); + new SqlTypeMapper<>( + type, + "INTEGER(" + (int)Math.ceil(enumConstants.length/256.0) + ")", + (rs, identifier) -> enumConstants[rs.getInt(identifier)], + (st, index, value) -> st.setInt(index, value.ordinal()) + ); + } + + public static > void nameEnumMapper(Class type) { + new SqlTypeMapper<>( + type, + "VARCHAR(" + Arrays.stream(type.getEnumConstants()).map(e -> e.name().length()).max(Integer::compareTo) + ")", + (rs, identifier) -> Enum.valueOf(type, rs.getString(identifier)), + (st, index, value) -> st.setString(index, value.name()) + ); + } + + private final String sqlType; + private final SQLReader reader; + private final SQLWriter writer; + + public SqlTypeMapper(Class clazz, String sqlType, SQLReader reader, SQLWriter writer) { + this.sqlType = sqlType; + this.reader = reader; + this.writer = writer; + mappers.put(clazz, this); + } + + public T read(ResultSet rs, String identifier) throws SQLException { + return reader.read(rs, identifier); + } + + public void write(PreparedStatement st, int index, Object value) throws SQLException { + writer.write(st, index, (T) value); + } + + public String sqlType() { + return sqlType; + } + + @FunctionalInterface + public interface SQLReader { + T read(ResultSet rs, String identifier) throws SQLException; + } + + @FunctionalInterface + public interface SQLWriter { + void write(PreparedStatement st, int index, T value) throws SQLException; + } +} diff --git a/src/de/steamwar/sql/SqlTypeParser.java b/src/de/steamwar/sql/SqlTypeParser.java deleted file mode 100644 index 377097c..0000000 --- a/src/de/steamwar/sql/SqlTypeParser.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * This file is a part of the SteamWar software. - * - * Copyright (C) 2022 SteamWar.de-Serverteam - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package de.steamwar.sql; - -import java.sql.ResultSet; -import java.sql.SQLException; - -public interface SqlTypeParser { - T parse(ResultSet rs, Field field) throws SQLException; -} diff --git a/src/de/steamwar/sql/Statement.java b/src/de/steamwar/sql/Statement.java index bc9e32f..9626f0e 100644 --- a/src/de/steamwar/sql/Statement.java +++ b/src/de/steamwar/sql/Statement.java @@ -24,6 +24,7 @@ import java.io.FileReader; import java.io.IOException; import java.sql.*; import java.util.*; +import java.util.function.Consumer; import java.util.function.Supplier; import java.util.logging.Level; import java.util.logging.Logger; @@ -36,8 +37,10 @@ public class Statement implements AutoCloseable { private static final Deque connections = new ArrayDeque<>(); private static final int MAX_CONNECTIONS; private static final Supplier conProvider; + static final Consumer> schemaCreator; + static { - File file = new File(new File("plugins", "SpigotCore"), "mysql.properties"); + File file = new File(System.getProperty("user.home"), "mysql.properties"); if(file.exists()) { Properties properties = new Properties(); @@ -59,19 +62,20 @@ public class Statement implements AutoCloseable { throw new SecurityException("Could not create MySQL connection", e); } }; + schemaCreator = table -> {}; } else { MAX_CONNECTIONS = 1; Connection connection; try { Class.forName("org.sqlite.JDBC"); - connection = DriverManager.getConnection("jdbc:sqlite:standalone.db"); - //TODO schema + connection = DriverManager.getConnection("jdbc:sqlite:" + System.getProperty("user.home") + "/standalone.db"); } catch (SQLException | ClassNotFoundException e) { throw new SecurityException("Could not create sqlite connection", e); } conProvider = () -> connection; + schemaCreator = Table::ensureExistanceInSqlite; } } @@ -151,7 +155,8 @@ public class Statement implements AutoCloseable { } for (int i = 0; i < objects.length; i++) { - st.setObject(i + 1, objects[i]); + Object o = objects[i]; + SqlTypeMapper.getMapper(o.getClass()).write(st, i+1, o); } return runnable.run(st); diff --git a/src/de/steamwar/sql/Table.java b/src/de/steamwar/sql/Table.java index b8971f4..74dfc83 100644 --- a/src/de/steamwar/sql/Table.java +++ b/src/de/steamwar/sql/Table.java @@ -19,96 +19,113 @@ package de.steamwar.sql; +import java.lang.reflect.Constructor; import java.sql.ResultSet; import java.sql.SQLException; import java.util.*; +import java.util.function.Function; import java.util.stream.Collectors; -public class Table { +public class Table { + public static final String PRIMARY = "primary"; - private final String name; - private final Set> keys; - private final Field[] fields; - private final Map, Integer> fieldIds = new HashMap<>(); - private final List keyIds; + final String name; + final TableField[] fields; + private final Map> fieldsByIdentifier = new HashMap<>(); + final Constructor constructor; - private final Map[], Statement> cachedSelect = new HashMap<>(); - private final Map[], Statement> cachedInsert = new HashMap<>(); - private final Map[], Statement> cachedUpdate = new HashMap<>(); + private final Map[]> keys; - public Table(String name, Field[] keys, Field... fields) { + + public Table(Class clazz) { + this(clazz, clazz.getSimpleName()); + } + + public Table(Class clazz, String name) { this.name = name; - this.keys = Arrays.stream(keys).collect(Collectors.toSet()); - this.fields = fields; - for(int i = 0; i < fields.length; i++) { - fieldIds.put(fields[i], i); + this.fields = Arrays.stream(clazz.getDeclaredFields()).filter(field -> field.isAnnotationPresent(Field.class)).map(TableField::new).toArray(TableField[]::new); + try { + this.constructor = clazz.getDeclaredConstructor(Arrays.stream(clazz.getDeclaredFields()).filter(field -> field.isAnnotationPresent(Field.class)).map(java.lang.reflect.Field::getType).toArray(Class[]::new)); + } catch (NoSuchMethodException e) { + throw new SecurityException(e); } - keyIds = Arrays.stream(keys).map(fieldIds::get).collect(Collectors.toList()); - } - public Row selectSingle(Field[] fields, Object... values) { - return select(rs -> { - if(rs.next()) - return read(rs); - return null; - }, fields, values); - } + keys = Arrays.stream(fields).flatMap(field -> Arrays.stream(field.field.keys())).distinct().collect(Collectors.toMap(Function.identity(), key -> Arrays.stream(fields).filter(field -> Arrays.asList(field.field.keys()).contains(key)).toArray(TableField[]::new))); - public List selectMulti(Field[] fields, Object... values) { - return select(rs -> { - List result = new ArrayList<>(); - while(rs.next()) - result.add(read(rs)); - - return result; - }, fields, values); - } - - public void insert(Field[] fields, Object... values) { - Statement statement; - synchronized (cachedInsert) { - statement = cachedInsert.computeIfAbsent(fields, fs -> new Statement("INSERT INTO " + name + " (" + Arrays.stream(fs).map(Field::identifier).collect(Collectors.joining(", ")) + ") VALUES (" + Arrays.stream(fs).map(f -> "?").collect(Collectors.joining(", ")) + ")")); + for (TableField field : fields) { + fieldsByIdentifier.put(field.identifier, field); } - statement.update(values); + + Statement.schemaCreator.accept(this); } - public void update(Object[] keyvalues, Field[] fields, Object... values) { - Statement statement; - synchronized (cachedUpdate) { - statement = cachedUpdate.computeIfAbsent(fields, fs -> new Statement("UPDATE " + name + " SET " + Arrays.stream(fs).map(f -> f.identifier() + " = ?").collect(Collectors.joining(", ")) + " WHERE " + keys.stream().map(f -> f.identifier() + " = ?").collect(Collectors.joining(", ")))); - } - statement.update(values, keyvalues); + + public SelectStatement select(String name) { + return selectFields(keyFields(name)); + } + public SelectStatement selectFields(String... kfields) { + return new SelectStatement<>(this, kfields); } - public void create() { - //TODO syntax mysql/sqlite - try (Statement statement = new Statement("CREATE TABLE IF NOT EXISTS " + name + "(" + Arrays.stream(fields).map(field -> field.identifier() + " " + field.sqlType() + (keys.contains(field) ? " PRIMARY KEY" : "")).collect(Collectors.joining(", ")) + ") STRICT")) { + public Statement update(String name, String... fields) { + return updateFields(fields, keyFields(name)); + } + + public Statement updateField(String field, String... kfields) { + return updateFields(new String[]{field}, kfields); + } + + public Statement updateFields(String[] fields, String... kfields) { + return new Statement("UPDATE " + name + " SET " + Arrays.stream(fields).map(f -> f + " = ?").collect(Collectors.joining(", ")) + " WHERE " + Arrays.stream(kfields).map(f -> f + " = ?").collect(Collectors.joining(", "))); + } + + public Statement insert(String name) { + return insertFields(keyFields(name)); + } + + public Statement insertFields(String... fields) { + List nonKeyFields = Arrays.stream(fields).filter(f -> fieldsByIdentifier.get(f).field.keys().length == 0).collect(Collectors.toList()); + return new Statement("INSERT INTO " + name + " (" + String.join(", ", fields) + ") VALUES (" + Arrays.stream(fields).map(f -> "?").collect(Collectors.joining(", ")) + ")" + (nonKeyFields.isEmpty() ? "" : " ON DUPLICATE KEY UPDATE " + nonKeyFields.stream().map(f -> f + " = VALUES(" + f + ")").collect(Collectors.joining(", ")))); + } + + public Statement deleteWithKey(String name) { + return delete(keyFields(name)); + } + + public Statement delete(String... kfields) { + return new Statement("DELETE FROM " + name + " WHERE " + Arrays.stream(kfields).map(f -> f + " = ?").collect(Collectors.joining(", "))); + } + + void ensureExistanceInSqlite() { + List> primaryKey = keys.containsKey(PRIMARY) ? Arrays.asList(keys.get(PRIMARY)) : Collections.emptyList(); + try (Statement statement = new Statement( + "CREATE TABLE IF NOT EXISTS " + name + "(" + + Arrays.stream(fields).map(field -> field.identifier + " " + field.mapper.sqlType() + (field.field.nullable() ? "" : " NOT NULL DEFAULT NULL") + (!field.field.nullable() && field.field.def().equals("") ? "" : " DEFAULT " + field.field.def()) + (primaryKey.contains(field) ? " PRIMARY KEY" : "") + (field.field.autoincrement() ? " AUTOINCREMENT" : "")).collect(Collectors.joining(", ")) + + keys.entrySet().stream().filter(entry -> !entry.getKey().equals(PRIMARY)).map(key -> ", UNIQUE (" + Arrays.stream(key.getValue()).map(field -> field.identifier).collect(Collectors.joining(", ")) + ")").collect(Collectors.joining(" ")) + + ") STRICT, WITHOUT ROWID")) { statement.update(); } } - public List keyIds() { - return keyIds; + private String[] keyFields(String name) { + return Arrays.stream(keys.get(name)).map(f -> f.identifier).toArray(String[]::new); } - public int getFieldId(Field field) { - return fieldIds.get(field); - } + static class TableField { - private T select(Statement.ResultSetUser u, Field[] fields, Object... values) { - Statement statement; - synchronized (cachedSelect) { - statement = cachedSelect.computeIfAbsent(fields, fs -> new Statement("SELECT " + Arrays.stream(fields).map(Field::identifier).collect(Collectors.joining(", ")) + " FROM " + name + " WHERE " + Arrays.stream(fs).map(f -> f.identifier() + " = ?").collect(Collectors.joining(", ")))); - } - return statement.select(u, values); - } + final String identifier; - private Row read(ResultSet rs) throws SQLException { - Object[] values = new Object[fields.length]; - for(int i = 0; i < fields.length; i++) { - values[i] = fields[i].parse(rs); + final SqlTypeMapper mapper; + private final Field field; + + private TableField(java.lang.reflect.Field field) { + this.identifier = field.getName(); + this.mapper = (SqlTypeMapper) SqlTypeMapper.getMapper(field.getDeclaringClass()); + this.field = field.getAnnotation(Field.class); } - return new Row(this, values); + T read(ResultSet rs) throws SQLException { + return mapper.read(rs, identifier); + } } }