#!/usr/bin/env python
# cardinal_pythonlib/sqlalchemy/sqlfunc.py
"""
===============================================================================
Original code copyright (C) 2009-2022 Rudolf Cardinal (rudolf@pobox.com).
This file is part of cardinal_pythonlib.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
===============================================================================
**Functions to operate on SQL clauses for SQLAlchemy Core.**
"""
from typing import NoReturn, TYPE_CHECKING
from sqlalchemy.ext.compiler import compiles
# noinspection PyProtectedMember
from sqlalchemy.sql.expression import FunctionElement
from sqlalchemy.sql.sqltypes import Numeric
from cardinal_pythonlib.sqlalchemy.dialect import SqlaDialectName
from cardinal_pythonlib.logs import get_brace_style_log_with_null_handler
if TYPE_CHECKING:
from sqlalchemy.sql.elements import ClauseElement, ClauseList
from sqlalchemy.sql.compiler import SQLCompiler
log = get_brace_style_log_with_null_handler(__name__)
# =============================================================================
# Support functions
# =============================================================================
[docs]def fail_unknown_dialect(compiler: "SQLCompiler", task: str) -> NoReturn:
"""
Raise :exc:`NotImplementedError` in relation to a dialect for which a
function hasn't been implemented (with a helpful error message).
"""
raise NotImplementedError(
f"Don't know how to {task} on dialect {compiler.dialect!r}. "
f"(Check also: if you printed the SQL before it was bound to an "
f"engine, you will be trying to use a dialect like StrSQLCompiler, "
f"which could be a reason for failure.)"
)
[docs]def fetch_processed_single_clause(
element: "ClauseElement", compiler: "SQLCompiler"
) -> str:
"""
Takes a clause element that must have a single clause, and converts it
to raw SQL text.
"""
if len(element.clauses) != 1:
raise TypeError(
f"Only one argument supported; "
f"{len(element.clauses)} were passed"
)
clauselist = element.clauses # type: ClauseList
first = next(clauselist.get_children())
return compiler.process(first)
# =============================================================================
# Extract year from a DATE/DATETIME etc.
# =============================================================================
# noinspection PyPep8Naming
# noinspection PyUnusedLocal
# noinspection PyUnusedLocal
@compiles(extract_year, SqlaDialectName.SQLSERVER)
@compiles(extract_year, SqlaDialectName.MYSQL)
def extract_year_year(
element: "ClauseElement", compiler: "SQLCompiler", **kw
) -> str:
# https://dev.mysql.com/doc/refman/5.5/en/date-and-time-functions.html#function_year # noqa: E501
# https://docs.microsoft.com/en-us/sql/t-sql/functions/year-transact-sql
clause = fetch_processed_single_clause(element, compiler)
return f"YEAR({clause})"
# noinspection PyUnusedLocal
@compiles(extract_year, SqlaDialectName.ORACLE)
@compiles(extract_year, SqlaDialectName.POSTGRES)
def extract_year_extract(
element: "ClauseElement", compiler: "SQLCompiler", **kw
) -> str:
# https://docs.oracle.com/cd/B19306_01/server.102/b14200/functions050.htm
clause = fetch_processed_single_clause(element, compiler)
return f"EXTRACT(YEAR FROM {clause})"
# noinspection PyUnusedLocal
@compiles(extract_year, SqlaDialectName.SQLITE)
def extract_year_strftime(
element: "ClauseElement", compiler: "SQLCompiler", **kw
) -> str:
# https://sqlite.org/lang_datefunc.html
clause = fetch_processed_single_clause(element, compiler)
return f"STRFTIME('%Y', {clause})"
# =============================================================================
# Extract month from a DATE/DATETIME etc.
# =============================================================================
# noinspection PyPep8Naming
# noinspection PyUnusedLocal
@compiles(extract_month)
def extract_month_default(
element: "ClauseElement", compiler: "SQLCompiler", **kw
) -> NoReturn:
fail_unknown_dialect(compiler, "extract month from date")
# noinspection PyUnusedLocal
@compiles(extract_month, SqlaDialectName.SQLSERVER)
@compiles(extract_month, SqlaDialectName.MYSQL)
def extract_month_month(
element: "ClauseElement", compiler: "SQLCompiler", **kw
) -> str:
clause = fetch_processed_single_clause(element, compiler)
return f"MONTH({clause})"
# noinspection PyUnusedLocal
@compiles(extract_month, SqlaDialectName.ORACLE)
@compiles(extract_year, SqlaDialectName.POSTGRES)
def extract_month_extract(
element: "ClauseElement", compiler: "SQLCompiler", **kw
) -> str:
clause = fetch_processed_single_clause(element, compiler)
return f"EXTRACT(MONTH FROM {clause})"
# noinspection PyUnusedLocal
@compiles(extract_month, SqlaDialectName.SQLITE)
def extract_month_strftime(
element: "ClauseElement", compiler: "SQLCompiler", **kw
) -> str:
clause = fetch_processed_single_clause(element, compiler)
return f"STRFTIME('%m', {clause})"
# =============================================================================
# Extract day (day of month, not day of week) from a DATE/DATETIME etc.
# =============================================================================
# noinspection PyPep8Naming
# noinspection PyUnusedLocal
@compiles(extract_day_of_month)
def extract_day_of_month_default(
element: "ClauseElement", compiler: "SQLCompiler", **kw
) -> NoReturn:
fail_unknown_dialect(compiler, "extract day-of-month from date")
# noinspection PyUnusedLocal
@compiles(extract_day_of_month, SqlaDialectName.SQLSERVER)
@compiles(extract_day_of_month, SqlaDialectName.MYSQL)
def extract_day_of_month_day(
element: "ClauseElement", compiler: "SQLCompiler", **kw
) -> str:
clause = fetch_processed_single_clause(element, compiler)
return f"DAY({clause})"
# noinspection PyUnusedLocal
@compiles(extract_day_of_month, SqlaDialectName.ORACLE)
@compiles(extract_year, SqlaDialectName.POSTGRES)
def extract_day_of_month_extract(
element: "ClauseElement", compiler: "SQLCompiler", **kw
) -> str:
clause = fetch_processed_single_clause(element, compiler)
return f"EXTRACT(DAY FROM {clause})"
# noinspection PyUnusedLocal
@compiles(extract_day_of_month, SqlaDialectName.SQLITE)
def extract_day_of_month_strftime(
element: "ClauseElement", compiler: "SQLCompiler", **kw
) -> str:
clause = fetch_processed_single_clause(element, compiler)
return f"STRFTIME('%d', {clause})"