# Copyright 2019 SUSE Linux GmbH
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import logging
from typing import Any
from oslo_serialization import jsonutils
from osprofiler.drivers import base
from osprofiler import exc
# Module-wide logger; the driver reports its deliberately non-fatal
# failures (missing SQLAlchemy, unreachable database, failed writes) here.
LOG = logging.getLogger(name=__name__)
class SQLAlchemyDriver(base.Driver):
    """osprofiler driver that stores trace points in an SQL database.

    Each notification becomes one row of a ``data`` table. The connection
    string is handed to SQLAlchemy's ``create_engine``. SQLAlchemy is
    imported lazily inside the methods, so it is only needed when this
    driver is actually used.

    The driver works with SQLAlchemy 2.0 as well as earlier releases:

    * writes run in an explicit ``engine.begin()`` transaction, because 2.0
      removed implicit autocommit;
    * statements use ``Table.select()`` rather than the removed
      ``select([table])`` form;
    * result rows are read through their ``_mapping`` view when available.

    Connections are checked out per operation instead of being held open
    for the lifetime of the driver.
    """

    def __init__(
        self,
        connection_str: str,
        project: str | None = None,
        service: str | None = None,
        host: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Define the ``data`` table and create it if it does not exist.

        :param connection_str: SQLAlchemy database URL, forwarded to
            ``base.Driver`` and to ``create_engine``.
        :param project: forwarded to ``base.Driver``; ``notify()`` falls
            back to ``self.project`` when a trace point carries none.
        :param service: forwarded to ``base.Driver``; ``notify()`` falls
            back to ``self.service`` when a trace point carries none.
        :param host: forwarded to ``base.Driver``; ``notify()`` falls back
            to ``self.host`` when a trace point carries none.
        :param kwargs: accepted for driver-API compatibility; ignored.

        Problems (SQLAlchemy missing, database unreachable) are logged and
        swallowed on purpose: a broken profiling backend must never take
        down the service being profiled.
        """
        super().__init__(
            connection_str, project=project, service=service, host=host
        )
        try:
            from sqlalchemy import create_engine  # type: ignore[import-not-found]
            from sqlalchemy import Table, MetaData, Column
            from sqlalchemy import String, JSON, Integer
        except ImportError:
            LOG.exception("To use this command, install 'SQLAlchemy'")
            return

        self._metadata = MetaData()
        self._data_table = Table(
            "data",
            self._metadata,
            Column("id", Integer, primary_key=True),
            # timestamp - date/time of the trace point
            Column("timestamp", String(26), index=True),
            # base_id - uuid common for all notifications related to
            # one trace
            Column("base_id", String(255), index=True),
            # parent_id - uuid of parent element in trace
            Column("parent_id", String(255), index=True),
            # trace_id - uuid of current element in trace
            Column("trace_id", String(255), index=True),
            Column("project", String(255), index=True),
            Column("host", String(255), index=True),
            Column("service", String(255), index=True),
            # name - trace point name
            Column("name", String(255), index=True),
            Column("data", JSON),
        )

        # We don't want to kill any service that does use osprofiler.
        try:
            self._engine = create_engine(connection_str)
            # FIXME(toabctl): Not the best idea to create the table on every
            # startup when using the sqlalchemy driver...
            # create_all() connects to the database itself, so connectivity
            # problems are still reported here at startup.
            self._metadata.create_all(self._engine, checkfirst=True)
        except Exception:
            LOG.exception(
                "Failed to create engine/connection and setup "
                "intial database tables"
            )

    @classmethod
    def get_name(cls) -> str:
        """Return the driver name ``"sqlalchemy"``."""
        return "sqlalchemy"

    @staticmethod
    def _require_sqlalchemy() -> None:
        """Raise ``exc.CommandError`` if SQLAlchemy cannot be imported."""
        try:
            import sqlalchemy  # type: ignore[import-not-found]  # noqa: F401
        except ImportError as err:
            raise exc.CommandError(
                "To use this command, you should install 'SQLAlchemy'"
            ) from err

    def _fetch_rows(self, stmt: Any) -> list[Any]:
        """Execute *stmt* on a fresh connection and return dict-like rows."""
        with self._engine.connect() as conn:
            rows = conn.execute(stmt).fetchall()
        # SQLAlchemy 2.0 rows no longer support row["col"] or row.items();
        # their ``_mapping`` view (1.4+) does. Older releases return rows
        # that are already dict-like, so fall back to the row itself.
        return [getattr(row, "_mapping", row) for row in rows]

    def notify(
        self, info: dict[str, Any], context: Any = None, **kwargs: Any
    ) -> None:
        """Write a notification to the database.

        The keys ``base_id``, ``timestamp``, ``parent_id``, ``trace_id``,
        ``project``, ``host``, ``service`` and ``name`` are moved out of
        ``info`` into their own columns. Missing ``project``, ``host`` and
        ``service`` fall back to the driver defaults. Everything else is
        stored in the ``data`` column. ``info`` itself is not modified.

        :param info: trace point payload.
        :param context: unused; part of the driver interface.
        :param kwargs: unused; part of the driver interface.

        Any failure is logged and swallowed, so profiling never breaks the
        caller.
        """
        data = info.copy()
        base_id = data.pop("base_id", None)
        timestamp = data.pop("timestamp", None)
        parent_id = data.pop("parent_id", None)
        trace_id = data.pop("trace_id", None)
        project = data.pop("project", self.project)
        host = data.pop("host", self.host)
        service = data.pop("service", self.service)
        name = data.pop("name", None)
        try:
            ins = self._data_table.insert().values(
                timestamp=timestamp,
                base_id=base_id,
                parent_id=parent_id,
                trace_id=trace_id,
                project=project,
                service=service,
                host=host,
                name=name,
                # NOTE: pre-serialized to a JSON string even though the
                # column type is JSON; get_report() undoes this with
                # jsonutils.loads, so both sides must stay in sync.
                data=jsonutils.dumps(data),
            )
            # engine.begin() commits on success. SQLAlchemy 2.0 connections
            # no longer autocommit, so a bare execute() would be rolled back.
            with self._engine.begin() as conn:
                conn.execute(ins)
        except Exception:
            LOG.exception(
                "Can not store osprofiler tracepoint %s (base id %s)",
                trace_id,
                base_id,
            )

    def list_traces(
        self, fields: set[str] | None = None
    ) -> list[dict[str, Any]]:
        """Return one summary dict per distinct ``base_id``.

        :param fields: column names to include in each dict; defaults to
            ``self.default_trace_fields``.
        :returns: one dict per trace, built from the first row the database
            returns for that ``base_id``. No ORDER BY is applied.
        :raises exc.CommandError: if SQLAlchemy is not installed.
        """
        self._require_sqlalchemy()
        wanted = set(fields or self.default_trace_fields)
        seen_ids: set[str] = set()
        result: list[dict[str, Any]] = []
        for row in self._fetch_rows(self._data_table.select()):
            if row["base_id"] in seen_ids:
                continue
            seen_ids.add(row["base_id"])
            result.append(
                {key: value for key, value in row.items() if key in wanted}
            )
        return result

    def get_report(self, base_id: str) -> dict[str, Any]:
        """Build the report for the trace identified by *base_id*.

        Every stored trace point of the trace is passed to
        ``self._append_results``. The result of ``self._parse_results()``
        is then returned.

        :raises exc.CommandError: if SQLAlchemy is not installed.
        """
        self._require_sqlalchemy()
        stmt = self._data_table.select().where(
            self._data_table.c.base_id == base_id
        )
        for row in self._fetch_rows(stmt):
            self._append_results(
                row["trace_id"],
                row["parent_id"],
                row["name"],
                row["project"],
                row["service"],
                row["host"],
                row["timestamp"],
                jsonutils.loads(row["data"]),
            )
        return self._parse_results()