From 1a414d5764e970681efd398c99c697aed6c5a3f7 Mon Sep 17 00:00:00 2001
From: MauricioGarciaS <47052044+MauricioGarciaS@users.noreply.github.com>
Date: Tue, 9 May 2023 11:35:50 +0200
Subject: [PATCH] fix(connector): Add checkpoints and SIGTERM handler (#1234)

* fix(connector): fixed a bug in the cache dict size check
* fix(connector): added a method to save state to S3 (Redshift) when a SIGTERM arrives
* fix(connector): added an exit-signal handler and a checkpoint method
* added sslmode selection for the database connection and a use_ssl parameter for the S3 connection
---
 ee/connectors/consumer_async.py       | 18 +++++++--
 ee/connectors/db/api.py               | 54 ++++++++++++++++++++++++---
 ee/connectors/utils/cache.py          | 24 +++++++++++-
 ee/connectors/utils/signal_handler.py | 14 +++++++
 4 files changed, 98 insertions(+), 12 deletions(-)
 create mode 100644 ee/connectors/utils/signal_handler.py

diff --git a/ee/connectors/consumer_async.py b/ee/connectors/consumer_async.py
index 99ae5a087..2e6054606 100644
--- a/ee/connectors/consumer_async.py
+++ b/ee/connectors/consumer_async.py
@@ -1,15 +1,13 @@
-from numpy._typing import _16Bit
 from decouple import config, Csv
 from confluent_kafka import Consumer
 from datetime import datetime
 from collections import defaultdict
-import json
 import asyncio
 from time import time, sleep
 from copy import deepcopy
 
 from msgcodec.msgcodec import MessageCodec
-from msgcodec.messages import SessionStart, SessionEnd
+from msgcodec.messages import SessionEnd
 from db.api import DBConnection
 from db.models import events_detailed_table_name, events_table_name, sessions_table_name
 from db.writer import insert_batch, update_batch
@@ -18,6 +16,7 @@
 from utils.cache import ProjectFilter as PF
 from utils import pg_client
 from psycopg2 import InterfaceError
+from utils.signal_handler import signal_handler
 
 def process_message(msg, codec, sessions, batch, sessions_batch, interesting_sessions, interesting_events, EVENT_TYPE, projectFilter):
     if msg is None:
@@ -174,6 +173,11 @@ async def main():
     allowed_projects = config('PROJECT_IDS', default=None, cast=Csv(int))
     project_filter = PF(allowed_projects)
+    try:
+        project_filter.load_checkpoint(db)
+    except Exception as e:
+        print('[WARN] Checkpoint not found')
+        print(repr(e))
     codec = MessageCodec(filter_events)
     ssl_protocol = config('KAFKA_USE_SSL', default=True, cast=bool)
     consumer_settings = {
@@ -191,7 +195,7 @@
     c_time = time()
     read_msgs = 0
-    while True:
+    while signal_handler.KEEP_PROCESSING:
         msg = consumer.poll(1.0)
         process_message(msg, codec, sessions, batch, sessions_batch, sessions_events_selection, selected_events, EVENT_TYPE, project_filter)
         read_msgs += 1
@@ -199,10 +203,16 @@
             print(f'[INFO] {read_msgs} kafka messages read in {upload_rate} seconds')
             await insertBatch(deepcopy(sessions_batch), deepcopy(batch), db, sessions_table_name, table_name, EVENT_TYPE)
             consumer.commit()
+            try:
+                project_filter.save_checkpoint(db)
+            except Exception as e:
+                print("[ERROR] Error while saving checkpoint")
+                print(repr(e))
             sessions_batch = []
             batch = []
             read_msgs = 0
             c_time = time()
+    project_filter.terminate(db)
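Note on the lifecycle the hunks above establish: load_checkpoint() restores state at startup, the poll loop runs only while signal_handler.KEEP_PROCESSING holds, save_checkpoint() persists state after every commit, and terminate() releases the connection once the loop drains. A minimal runnable sketch of that pattern, where FakeDB, FakeFilter, and the message iterator are hypothetical stand-ins for DBConnection, ProjectFilter, and consumer.poll():

    import signal

    class FakeDB:                      # stand-in for DBConnection
        def close(self):
            print('db closed')

    class FakeFilter:                  # stand-in for ProjectFilter
        def save_checkpoint(self, db):
            print('checkpoint saved')

        def terminate(self, db):
            db.close()

    KEEP_PROCESSING = True

    def exit_gracefully(signum, frame):
        # Flip a flag instead of exiting: the current iteration finishes,
        # commits, and checkpoints before the process stops.
        global KEEP_PROCESSING
        KEEP_PROCESSING = False

    signal.signal(signal.SIGTERM, exit_gracefully)

    db, project_filter = FakeDB(), FakeFilter()
    messages = iter(range(3))
    while KEEP_PROCESSING:
        msg = next(messages, None)     # the real loop calls consumer.poll(1.0)
        if msg is None:                # queue drained: simulate an external SIGTERM
            signal.raise_signal(signal.SIGTERM)
            continue
        project_filter.save_checkpoint(db)   # in the patch: right after consumer.commit()
    project_filter.terminate(db)

Checking the flag in the loop condition, rather than exiting inside the handler, is what lets an in-flight batch reach consumer.commit() and save_checkpoint() before shutdown.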
diff --git a/ee/connectors/db/api.py b/ee/connectors/db/api.py
index daeced903..4c449105b 100644
--- a/ee/connectors/db/api.py
+++ b/ee/connectors/db/api.py
@@ -4,11 +4,19 @@ from sqlalchemy.orm import sessionmaker, session
 from contextlib import contextmanager
 import logging
 from decouple import config as _config
+from decouple import Choices
 from pathlib import Path
+import io
 
 DATABASE = _config('CLOUD_SERVICE')
+sslmode = _config('DB_SSLMODE',
+                  cast=Choices(['disable', 'allow', 'prefer', 'require', 'verify-ca', 'verify-full']),
+                  default='allow'
+                  )
 
 if DATABASE == 'redshift':
     import pandas_redshift as pr
+    import botocore
+    use_ssl = _config('S3_USE_SSL', default=True, cast=bool)
 
 base_path = Path(__file__).parent.parent
@@ -62,12 +70,14 @@
                 host=cluster_info['HOST'],
                 port=cluster_info['PORT'],
                 user=cluster_info['USER'],
-                password=cluster_info['PASSWORD'])
+                password=cluster_info['PASSWORD'],
+                sslmode=sslmode)
 
             self.pdredshift.connect_to_s3(aws_access_key_id=_config('AWS_ACCESS_KEY_ID'),
                                           aws_secret_access_key=_config('AWS_SECRET_ACCESS_KEY'),
                                           bucket=_config('BUCKET'),
-                                          subdirectory=_config('SUBDIRECTORY', default=None))
+                                          subdirectory=_config('SUBDIRECTORY', default=None),
+                                          use_ssl=use_ssl)
 
             self.CONNECTION_STRING = _config('CONNECTION_STRING').format(
                 USER=cluster_info['USER'],
@@ -76,14 +86,14 @@
                 PORT=cluster_info['PORT'],
                 DBNAME=cluster_info['DBNAME']
             )
-            self.engine = create_engine(self.CONNECTION_STRING)
+            self.engine = create_engine(self.CONNECTION_STRING, connect_args={'sslmode': sslmode})
         elif config == 'clickhouse':
             self.CONNECTION_STRING = _config('CONNECTION_STRING').format(
                 HOST=_config('HOST'),
                 DATABASE=_config('DATABASE')
             )
-            self.engine = create_engine(self.CONNECTION_STRING)
+            self.engine = create_engine(self.CONNECTION_STRING, connect_args={'sslmode': sslmode})
         elif config == 'pg':
             self.CONNECTION_STRING = _config('CONNECTION_STRING').format(
                 USER=_config('USER'),
@@ -92,7 +102,7 @@
                 PORT=_config('PORT'),
                 DATABASE=_config('DATABASE')
             )
-            self.engine = create_engine(self.CONNECTION_STRING)
+            self.engine = create_engine(self.CONNECTION_STRING, connect_args={'sslmode': sslmode})
         elif config == 'bigquery':
             pass
         elif config == 'snowflake':
@@ -104,7 +114,7 @@
                 DBNAME = _config('DBNAME'),
                 WAREHOUSE = _config('WAREHOUSE')
             )
-            self.engine = create_engine(self.CONNECTION_STRING)
+            self.engine = create_engine(self.CONNECTION_STRING, connect_args={'sslmode': sslmode})
         else:
             raise ValueError("This db configuration doesn't exist. Add into keys file.")
@@ -146,6 +156,38 @@ class DBConnection:
         self.close()
         self.__init__(config=self.config)
 
+    def save_binary(self, binary_data, name, **kwargs):
+        if self.config == 'redshift':
+            try:
+                self.pdredshift.core.s3.Object(bucket_name=self.pdredshift.core.s3_bucket_var,
+                                               key=self.pdredshift.core.s3_subdirectory_var + name).put(
+                    Body=binary_data, **kwargs)
+                print(f'[INFO] Content saved: {name}')
+            except botocore.exceptions.ClientError as err:
+                print(repr(err))
+
+    def load_binary(self, name):
+        if self.config == 'redshift':
+            try:
+                s3_object = self.pdredshift.core.s3.Object(bucket_name=self.pdredshift.core.s3_bucket_var,
+                                                           key=self.pdredshift.core.s3_subdirectory_var + name)
+                f = io.BytesIO()
+                s3_object.download_fileobj(f)
+                print(f'[INFO] Content downloaded: {name}')
+                return f
+            except botocore.exceptions.ClientError as err:
+                print(repr(err))
+
+    def delete_binary(self, name):
+        if self.config == 'redshift':
+            try:
+                s3_object = self.pdredshift.core.s3.Object(bucket_name=self.pdredshift.core.s3_bucket_var,
+                                                           key=self.pdredshift.core.s3_subdirectory_var + name)
+                s3_object.delete()
+                print(f'[INFO] s3 object {name} deleted')
+            except botocore.exceptions.ClientError as err:
+                print(repr(err))
+
     def close(self):
         if self.config == 'redshift':
             self.pdredshift.close_up_shop()
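Note on the three *_binary helpers above: they drive the boto3 S3 resource that pandas_redshift keeps in pr.core.s3 after connect_to_s3(). For reference, the same round trip with plain boto3 looks like the sketch below; the bucket name and 'checkpoints/' key prefix are placeholders, and credentials are assumed to come from the environment:

    import io
    import json

    import boto3

    s3 = boto3.resource('s3')
    obj = s3.Object(bucket_name='my-connector-bucket', key='checkpoints/checkpoint')

    # save_binary: serialize the filter state and upload it
    state = {'cache': {}, 'to_clean': [], 'cached_sessions': {}}
    obj.put(Body=json.dumps(state).encode('utf-8'))

    # load_binary: download into a BytesIO, as DBConnection.load_binary does
    buf = io.BytesIO()
    obj.download_fileobj(buf)
    assert json.loads(buf.getvalue().decode('utf-8')) == state

    # delete_binary: drop the checkpoint object
    obj.delete()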
diff --git a/ee/connectors/utils/cache.py b/ee/connectors/utils/cache.py
index 24dbda400..2f4050a52 100644
--- a/ee/connectors/utils/cache.py
+++ b/ee/connectors/utils/cache.py
@@ -1,7 +1,7 @@
 from utils.pg_client import PostgresClient
-from queue import Queue
 from decouple import config
 from time import time
+import json
 
 
 def _project_from_session(sessionId):
@@ -97,8 +97,28 @@
     def handle_clean(self):
         """Verifies and execute cleanup if needed"""
-        if len(self.filter)==0:
+        if len(self.filter) == 0:
             return
         elif len(self.cache) > self.max_cache_size:
             self.cleanup()
 
+    def load_checkpoint(self, db):
+        file = db.load_binary(name='checkpoint')
+        checkpoint = json.loads(file.getvalue().decode('utf-8'))
+        file.close()
+        self.cache = checkpoint['cache']
+        self.to_clean = checkpoint['to_clean']
+        self.cached_sessions.session_project = checkpoint['cached_sessions']
+
+    def save_checkpoint(self, db):
+        checkpoint = {
+            'cache': self.cache,
+            'to_clean': self.to_clean,
+            'cached_sessions': self.cached_sessions.session_project,
+        }
+        db.save_binary(binary_data=json.dumps(checkpoint).encode('utf-8'), name='checkpoint')
+
+    def terminate(self, db):
+        # self.save_checkpoint(db)
+        db.close()
+
diff --git a/ee/connectors/utils/signal_handler.py b/ee/connectors/utils/signal_handler.py
new file mode 100644
index 000000000..eabdbf29c
--- /dev/null
+++ b/ee/connectors/utils/signal_handler.py
@@ -0,0 +1,14 @@
+import signal
+
+class SignalHandler:
+    KEEP_PROCESSING = True
+    def __init__(self):
+        signal.signal(signal.SIGINT, self.exit_gracefully)
+        signal.signal(signal.SIGTERM, self.exit_gracefully)
+
+    def exit_gracefully(self, signum, frame):
+        print(f"Exiting gracefully with signal {signum}")
+        self.KEEP_PROCESSING = False
+
+
+signal_handler = SignalHandler()
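Note: a quick way to exercise the new handler locally on a POSIX system is to send the process its own SIGTERM mid-loop, standing in for the `kill <pid>` a deploy would issue; this assumes ee/connectors is on PYTHONPATH so utils.signal_handler imports:

    import os
    import signal

    from utils.signal_handler import signal_handler

    iterations = 0
    while signal_handler.KEEP_PROCESSING:
        iterations += 1
        if iterations == 3:
            # Delivered to this process before the next loop-condition check
            os.kill(os.getpid(), signal.SIGTERM)
    print(f'drained cleanly after {iterations} iterations')

The loop finishes its current iteration and exits at the next condition check, which is exactly the behavior the consumer relies on to commit and checkpoint before stopping.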