#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
#
# This script is licensed under GNU GPL version 2.0 or above
# (c) 2024 Antonio J. Delgado
"""Discover Mastodon servers by looking at public timelines"""

import sys
import os
import logging
from logging.handlers import SysLogHandler
import sqlite3
import time
import re
import click
import click_config_file
import requests


class DiscoverMastodonServers:
    '''Class to Discover Mastodon Servers

    Starting from the servers stored in the SQLite database (or from the
    ``initial_server`` option when the database is empty), each discovery
    pass queries every known, non-private, non-failing server for its peer
    list and its public timeline, and stores every host name not seen
    before. Passes are repeated until one finds no new servers.
    '''

    # Extracts the host part of a status URI such as
    # "https://example.social/users/x/statuses/1".
    _URI_HOST_RE = re.compile(r'https?://([^/]*)/')

    def __init__(self, **kwargs):
        '''Set up logging, the HTTP session and the database, then discover.

        Keyword arguments are the command line options declared on
        ``__main__``: debug_level, log_file, initial_server, proxy,
        database_file and regexp_banned_host.
        '''
        self.config = kwargs
        if 'log_file' not in kwargs or kwargs['log_file'] is None:
            self.config['log_file'] = os.path.join(
                os.environ.get(
                    'HOME',
                    os.environ.get('USERPROFILE', os.getcwd())
                ),
                'log',
                'discover-mastodon-servers.log'
            )
        self._init_log()
        self.session = requests.Session()
        # Only override the HTTPS proxy when one was actually given.
        if self.config.get('proxy'):
            self.session.proxies.update({'https': self.config['proxy']})
        self.conn = sqlite3.connect(self.config['database_file'])
        self.read_db()
        if len(self.servers) == 0:
            self._log.debug("Adding initial server.")
            self.servers[self.config['initial_server']] = {
                "name": self.config['initial_server'],
                "last_update": time.time(),
                "private": False,
                "peers": True,
                "timeline": True,
            }
        new_servers_count = 1
        while new_servers_count > 0:
            new_servers_count = self.discover()

    def get_timeline(self, server):
        '''Get the data of a public timeline for a given server'''
        return self.get_path(server, 'api/v1/timelines/public?remote=true&limit=40')

    def get_path(self, server, endpoint):
        '''Get the decoded JSON data of an endpoint of a server.

        ``server`` is a server record (dict with at least ``name``). Its
        ``status`` (HTTP status code) and ``state`` ('OK', 'Error' or
        'SSL Error') keys are updated with the outcome of the request.

        Returns the decoded JSON data, or None when the request fails, the
        reply is not JSON, or the reply is a JSON object with an ``error``
        key.
        '''
        try:
            # Tolerate endpoints given with a leading slash.
            url = f"https://{server['name']}/{endpoint.lstrip('/')}"
            result = self.session.get(url, timeout=10)
            server['status'] = result.status_code
            if result.status_code >= 400:
                server['state'] = 'Error'
                self._log.debug(
                    "Server '%s' returned error code %s.",
                    server['name'],
                    result.status_code
                )
                return None
            if 'Content-Type' not in result.headers:
                server['state'] = 'Error'
                self._log.debug(
                    "Server '%s' didn't return Content-Type header. Headers: '%s'. Content returned: '%s'",
                    server['name'],
                    result.headers,
                    result.content
                )
                return None
            if 'application/json' not in result.headers['Content-Type']:
                server['state'] = 'Error'
                self._log.debug(
                    "Server '%s' didn't reply with JSON data.",
                    server['name']
                )
                return None
            data = result.json()
            # Error replies are JSON objects with an "error" key; they must
            # not be handed to callers as if they were real data.
            if isinstance(data, dict) and 'error' in data:
                server['state'] = 'Error'
                self._log.debug(
                    "Server '%s' replied with an error: %s",
                    server['name'],
                    data['error']
                )
                return None
            server['state'] = 'OK'
            return data
        except requests.exceptions.ReadTimeout as error:
            server['state'] = 'Error'
            self._log.warning(
                "Server '%s' didn't respond on time. %s",
                server['name'],
                error
            )
        except requests.exceptions.SSLError as error:
            server['state'] = 'SSL Error'
            self._log.warning(
                "Server '%s' don't have a valid SSL certificate. %s",
                server['name'],
                error
            )
        except requests.exceptions.ConnectionError as error:
            server['state'] = 'Error'
            self._log.warning(
                "Server '%s' connection failed. %s",
                server['name'],
                error
            )
        except requests.exceptions.TooManyRedirects as error:
            server['state'] = 'Error'
            self._log.warning(
                "Server '%s' redirected too many times. %s",
                server['name'],
                error
            )
        except Exception as error:  # best-effort crawler: log and move on
            server['state'] = 'Error'
            self._log.warning(
                "Error fetching endpoint '%s' from server '%s'. %s",
                endpoint,
                server['name'],
                error
            )
        return None

    def get_instance_info(self, server):
        '''Get all server information.

        Stores the ``api/v1/instance`` reply in ``server['instance']`` (when
        available) and the accumulated pages of ``api/v1/directory`` in
        ``server['directory']``. Paging stops at the first empty, short or
        failed page.
        '''
        page_size = 80
        instance = self.get_path(server, 'api/v1/instance')
        if instance:
            server['instance'] = instance
        server['directory'] = []
        offset = 0
        while True:
            directory = self.get_path(
                server,
                f"api/v1/directory?limit={page_size}&offset={offset}"
            )
            if not directory or not isinstance(directory, list):
                break
            server['directory'].extend(directory)
            if len(directory) < page_size:
                break
            offset += page_size

    def test_banned_server(self, server_name):
        '''Check if a server name match agains any banned regular expressions'''
        for banned in self.config.get('regexp_banned_host') or ():
            if re.search(banned, server_name):
                self._log.debug(
                    "Regexp '%s' match server '%s', banned.",
                    banned,
                    server_name
                )
                return True
        return False

    def discover(self):
        '''Run one discovery pass over every known server.

        Every non-banned server that is neither private nor in an error
        state has its peer list and public timeline fetched, any host names
        not seen before are stored, and its own record is refreshed in the
        database. Servers found during the pass are added to
        ``self.servers`` once the pass finishes, so the next pass knows them
        and the ``__init__`` loop can terminate.

        Returns the number of new servers found (timeline hosts that turned
        out to be private are stored but not counted).
        '''
        # Every host name already handled, so nothing is counted twice.
        seen = set(self.servers)
        # New records found this pass; self.servers can't grow while being
        # iterated, so they are merged in after the loop.
        discovered = {}
        new_servers_count = 0
        for server_name, server in self.servers.items():
            if self.test_banned_server(server_name):
                continue
            server.setdefault('state', 'Unknown')
            server.setdefault('status', 0)
            server.setdefault('peers', True)
            server.setdefault('timeline', True)
            if server['private'] or 'Error' in server['state']:
                continue
            server['last_update'] = time.time()
            new_servers_count += self._discover_from_peers(
                server, seen, discovered
            )
            new_servers_count += self._discover_from_timeline(
                server, seen, discovered
            )
            self.write_record(server)
        self.servers.update(discovered)
        return new_servers_count

    def _discover_from_peers(self, server, seen, discovered):
        '''Store unseen, non-banned hosts from a server's peer list; return how many.'''
        self._log.debug("Fetching peers of the server '%s'", server['name'])
        data = self.get_path(server, 'api/v1/instance/peers')
        if not data or not isinstance(data, list):
            server['peers'] = False
            return 0
        count = 0
        for new_server in data:
            if not isinstance(new_server, str) or new_server in seen:
                continue
            seen.add(new_server)
            if self.test_banned_server(new_server):
                continue
            count += 1
            self._log.debug(
                "Adding new server '%s' from peers",
                new_server
            )
            record = {
                "name": new_server,
                "last_update": time.time(),
                "private": False,
                "peers": True,
                "timeline": True,
            }
            self.write_record(record)
            discovered[new_server] = record
        return count

    def _discover_from_timeline(self, server, seen, discovered):
        '''Store unseen, non-banned hosts from a server's public timeline; return how many are public.'''
        self._log.debug("Fetching public timeline in server '%s'", server['name'])
        data = self.get_timeline(server)
        if not data or not isinstance(data, list):
            server['timeline'] = False
            return 0
        count = 0
        for item in data:
            uri = item.get('uri') if isinstance(item, dict) else None
            if not isinstance(uri, str):
                # Item in public timeline don't have an URI
                self._log.debug(
                    "Item don't have URI. %s",
                    item
                )
                continue
            match_server = self._URI_HOST_RE.match(uri)
            if not match_server:
                continue
            new_server = match_server.group(1)
            if new_server in seen:
                continue
            seen.add(new_server)
            if self.test_banned_server(new_server):
                continue
            new_server_obj = {"name": new_server}
            # A host whose public timeline can't be read is stored as private.
            if self.get_timeline(new_server_obj):
                count += 1
                self._log.debug(
                    "Adding new server '%s' from timeline",
                    new_server
                )
                new_server_obj['private'] = False
            else:
                new_server_obj['private'] = True
            self.write_record(new_server_obj)
            discovered[new_server] = new_server_obj
        return count

    def write_record(self, record, table='servers'):
        '''Insert or update a record in a table, keyed by its ``name``.

        Missing ``state``, ``status``, ``peers``, ``timeline`` and
        ``last_update`` keys are filled into ``record`` in place. Only keys
        that are columns of ``table`` are written, and columns are always
        named explicitly, so the key order of ``record`` doesn't matter.
        Exits the program with status 1 if a database query fails.
        '''
        record.setdefault('state', 'Unknown')
        record.setdefault('status', 0)
        record.setdefault('peers', True)
        record.setdefault('timeline', True)
        if 'last_update' not in record:
            record['last_update'] = time.time()
        cur = self.conn.cursor()
        query = f"PRAGMA table_info({table})"
        try:
            columns = [row[1] for row in cur.execute(query).fetchall()]
            values = {
                key: value for key, value in record.items() if key in columns
            }
            # The name comes from remote servers: always bind it.
            query = f"SELECT name FROM {table} WHERE name = :name"
            exists = cur.execute(query, values).fetchone() is not None
            if exists:
                self._log.debug('Record exists, updating.')
                assignments = ", ".join(f"{key} = :{key}" for key in values)
                query = f"UPDATE {table} SET {assignments} WHERE name = :name"
            else:
                self._log.debug('Record doesn\'t exist, inserting.')
                query = (
                    f"INSERT INTO {table} ({', '.join(values)}) "
                    f"VALUES ({', '.join(':' + key for key in values)})"
                )
            self._log.debug("Writing record '%s'...", record)
            result_update = cur.execute(query, values)
            self._log.debug("Added record %s.", result_update.lastrowid)
        except sqlite3.Error as error:
            self._log.error("Error running query '%s' with record '%s'. %s", query, record, error)
            sys.exit(1)
        finally:
            cur.close()
        self.conn.commit()

    def read_db(self):
        '''Create the servers table if needed and load it into self.servers.

        Only name, last_update and private are loaded; the other fields are
        re-initialised by discover(). Exits with status 2 or 3 when the
        create or select query fails.
        '''
        cur = self.conn.cursor()
        query = """CREATE TABLE IF NOT EXISTS servers(
            name TEXT PRIMARY KEY,
            last_update REAL,
            private INT,
            peers INT,
            timeline INT,
            status INT,
            state TEXT
        )"""
        try:
            cur.execute(query)
        except Exception as error:
            self._log.error("Error running query to create table '%s'. %s", query, error)
            sys.exit(2)
        # Name the columns instead of relying on SELECT * column order.
        query = "SELECT name, last_update, private FROM servers ORDER BY last_update DESC"
        try:
            rows = cur.execute(query).fetchall()
        except Exception as error:
            self._log.error("Error running query to list servers '%s'. %s", query, error)
            sys.exit(3)
        self.servers = {
            name: {"name": name, "last_update": last_update, "private": private}
            for name, last_update, private in rows
        }
        self._log.debug("There are %s servers in the database.", len(self.servers))
        cur.close()
        self.conn.commit()

    def _init_log(self):
        ''' Initialize log object

        Logs to syslog (all levels), stdout (at ``debug_level``, default
        INFO) and a rotating log file (all levels).
        '''
        self._log = logging.getLogger("discover-mastodon-servers")
        self._log.setLevel(logging.DEBUG)

        sysloghandler = SysLogHandler()
        sysloghandler.setLevel(logging.DEBUG)
        self._log.addHandler(sysloghandler)

        streamhandler = logging.StreamHandler(sys.stdout)
        streamhandler.setLevel(
            logging.getLevelName(
                str(self.config.get("debug_level") or 'INFO').upper()
            )
        )
        self._log.addHandler(streamhandler)

        if self.config.get('log_file'):
            log_file = self.config['log_file']
        else:
            home_folder = os.environ.get(
                'HOME', os.environ.get('USERPROFILE', '')
            )
            log_folder = os.path.join(home_folder, "log")
            log_file = os.path.join(log_folder, "discover-mastodon-servers.log")

        # A bare file name has no directory part to create.
        log_folder = os.path.dirname(log_file)
        if log_folder:
            os.makedirs(log_folder, exist_ok=True)

        # backupCount must be non-zero, otherwise maxBytes never triggers a
        # rollover and the file grows without bound.
        filehandler = logging.handlers.RotatingFileHandler(
            log_file, maxBytes=102400000, backupCount=5
        )
        # create formatter
        formatter = logging.Formatter(
            '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
        )
        filehandler.setFormatter(formatter)
        filehandler.setLevel(logging.DEBUG)
        self._log.addHandler(filehandler)
        return True


@click.command()
@click.option(
    "--debug-level",
    "-d",
    default="INFO",
    type=click.Choice(
        ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"],
        case_sensitive=False,
    ),
    help='Set the debug level for the standard output.'
)
@click.option('--log-file', '-l', help="File to store all debug messages.")
# @click.option("--dummy","-n", is_flag=True,
#               help="Don't do anything, just show what would be done.")
@click.option(
    '--initial-server',
    '-i',
    default='mastodon.social',
    help='First Mastodon server to reach to read public timeline and discover others.'
)
@click.option('--proxy', '-p', help='Proxy URL to use.')
# '-D' because '-d' is already the short form of --debug-level.
@click.option(
    '--database-file',
    '-D',
    default='mastodon-servers.db',
    help='File with the database of results.'
)
@click.option(
    '--regexp-banned-host',
    '-r',
    multiple=True,
    help='Regular expression for banned host names.'
)
@click_config_file.configuration_option()
def __main__(**kwargs):
    '''Command line entry point: run the discovery with the given options.'''
    return DiscoverMastodonServers(**kwargs)


if __name__ == "__main__":
    __main__()