import csv
import datetime
import io
import os
import re
from pathlib import Path
from typing import Dict, List, Optional, Set

import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, Query, Request, UploadFile
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from tinydb import Query as TinyDBQuery
from tinydb import TinyDB

# Initialize FastAPI app
app = FastAPI(title="Domain Viewer", description="Upload and view domains and DNS records")

# Jinja2 templates are looked up in ./templates (relative to the working directory).
templates = Jinja2Templates(directory="templates")

# TinyDB database file; its directory is configurable through the DB_DIR env var.
DB_DIR = os.environ.get('DB_DIR', '.')
# Make sure the configured directory exists so a fresh DB_DIR works without manual setup.
Path(DB_DIR).mkdir(parents=True, exist_ok=True)
DB_PATH = Path(DB_DIR) / "domains_db.json"
db = TinyDB(DB_PATH)
domains_table = db.table('domains')
dns_records_table = db.table('dns_records')
uploads_table = db.table('uploads')


def process_domain_entry(domain_entry):
    """Normalize a raw domain cell into a domain document.

    Surrounding whitespace and a single trailing dot (as in fully-qualified
    names such as ``example.com.``) are removed.

    Args:
        domain_entry: The raw text of the CSV cell.

    Returns:
        ``{"full_domain": <name>}`` for a non-empty name, otherwise ``None``.
    """
    domain_entry = domain_entry.strip()
    if domain_entry.endswith('.'):
        domain_entry = domain_entry[:-1]
    if domain_entry:
        # Store only the full domain name without splitting it into labels.
        return {"full_domain": domain_entry}
    return None


def _build_dns_record(row, upload_id, timestamp):
    """Build a ``dns_records`` document from a CSV row with at least five fields.

    The row layout is ``name, ttl, class, type, data...``.
    """
    domain = row[0].strip()
    if domain.endswith('.'):
        domain = domain[:-1]

    entry = {
        "domain": domain,
        "ttl": row[1],
        "record_class": row[2],
        "record_type": row[3],
        # Unquoted commas inside the record data split it across several
        # columns, so every remaining column is re-joined.
        "record_data": ','.join(row[4:]),
        "upload_id": upload_id,
        "timestamp": timestamp,
    }

    # Service-style owner names (e.g. _dmarc.example.com) get their first label
    # stored separately.
    first_label = domain.split('.')[0]
    if first_label.startswith('_'):
        entry["service"] = first_label
    return entry


async def process_csv_upload(file_content, upload_id, description=None):
    """Parse an uploaded CSV file and store its domains and DNS records.

    Every non-empty row contributes its first column as a domain
    (deduplicated within this upload). Rows with at least five columns are
    additionally stored as DNS records (see ``_build_dns_record``). All
    documents inserted by one call share the same ``upload_id`` and
    ``timestamp``.

    Args:
        file_content: Raw bytes of the uploaded file, decoded as UTF-8; a
            leading byte-order mark is tolerated.
        upload_id: Identifier stamped on every inserted document.
        description: Currently unused; accepted for interface compatibility.

    Returns:
        ``(domains_inserted, dns_records_inserted)``, or ``(0, 0)`` if any
        error occurred (the error and traceback are printed, not raised).
    """
    domains_to_insert = []
    dns_records_to_insert = []
    seen_domains = set()
    timestamp = datetime.datetime.now().isoformat()
    row_count = 0

    try:
        # 'utf-8-sig' drops a BOM that would otherwise be glued to the first domain.
        file_text = file_content.decode('utf-8-sig')
        csv_reader = csv.reader(io.StringIO(file_text))

        # Print first few lines for debugging
        preview_lines = file_text.split('\n')[:5]
        print("CSV preview (first 5 lines):")
        for i, line in enumerate(preview_lines):
            print(f" Line {i+1}: {line}")

        for row_num, row in enumerate(csv_reader, 1):
            row_count += 1
            if not row:
                print(f"Skipping empty row at line {row_num}")
                continue

            # The first column always carries the domain name.
            domain_entry = row[0]
            domain_info = process_domain_entry(domain_entry)
            if domain_info:
                unique_key = domain_info['full_domain']
                if unique_key not in seen_domains:
                    seen_domains.add(unique_key)
                    domain_info['upload_id'] = upload_id
                    domain_info['timestamp'] = timestamp
                    domains_to_insert.append(domain_info)
            else:
                print(f"Warning: Could not process domain entry at line {row_num}: {domain_entry}")

            # Rows with enough fields also describe a DNS record.
            if len(row) >= 5:
                dns_records_to_insert.append(_build_dns_record(row, upload_id, timestamp))

        print(f"Processed {row_count} rows from CSV")
        print(f"Records to insert: {len(domains_to_insert)} domains, {len(dns_records_to_insert)} DNS records")

        # NOTE(review): the two inserts are not atomic; if the second fails,
        # the domains are already stored but (0, 0) is still returned.
        if domains_to_insert:
            print(f"Inserting {len(domains_to_insert)} domains into database")
            domains_table.insert_multiple(domains_to_insert)
        else:
            print("No domains to insert")

        if dns_records_to_insert:
            print(f"Inserting {len(dns_records_to_insert)} DNS records into database")
            dns_records_table.insert_multiple(dns_records_to_insert)
        else:
            print("No DNS records to insert")

        return len(domains_to_insert), len(dns_records_to_insert)
    except Exception as e:
        import traceback
        print(f"Error processing CSV file: {e}")
        print(traceback.format_exc())
        return 0, 0
def _timestamp_key(record: Dict) -> str:
    """Return a document's ``timestamp`` for sorting, treating missing/None as ''.

    Timestamps are compared as plain strings, which matches chronological
    order for the ``datetime.isoformat()`` values this module writes.
    """
    return record.get('timestamp') or ''


# Load domains from database - deduplicated by full domain name, with optional base domain filtering
def load_domains(specific_upload_id: Optional[str] = None, base_domains_only: bool = False) -> List[Dict]:
    """Return stored domain documents, newest first, one per ``full_domain``.

    Args:
        specific_upload_id: When truthy, only documents whose ``upload_id``
            matches are considered.
        base_domains_only: When true, drop documents whose ``full_domain``
            differs from ``extract_base_domain(full_domain)``.

    Returns:
        The most recent document (by ``timestamp``) for each distinct
        ``full_domain``, annotated with ``base_domain`` and ``is_latest=True``.
        An empty list if reading the database fails (the error is printed).
    """
    try:
        domains = domains_table.all()

        if specific_upload_id:
            domains = [d for d in domains if d.get('upload_id') == specific_upload_id]

        # Newest first, so the first occurrence of a name is its latest entry.
        domains.sort(key=_timestamp_key, reverse=True)

        latest: Dict[str, Dict] = {}
        for domain in domains:
            full_domain = domain.get('full_domain', '')
            if not full_domain or full_domain in latest:
                continue
            base_domain = extract_base_domain(full_domain)
            if base_domains_only and full_domain != base_domain:
                continue
            domain['base_domain'] = base_domain
            domain['is_latest'] = True
            latest[full_domain] = domain

        return list(latest.values())
    except Exception as e:
        print(f"Error loading domains from database: {e}")
        return []


def load_dns_entries(specific_upload_id: Optional[str] = None, deduplicate: bool = False) -> List[Dict]:
    """Return stored DNS record documents, newest first.

    Args:
        specific_upload_id: When truthy, only records from that upload are returned.
        deduplicate: When true, keep only the most recent record for each
            (domain, class, type, TTL, data) combination and mark it
            ``is_latest=True``.

    Returns:
        The records, or an empty list if reading the database fails (the
        error is printed).
    """
    try:
        entries = dns_records_table.all()

        if specific_upload_id:
            entries = [e for e in entries if e.get('upload_id') == specific_upload_id]

        entries.sort(key=_timestamp_key, reverse=True)

        if not deduplicate:
            return entries

        unique_entries: Dict[tuple, Dict] = {}
        for entry in entries:
            # A tuple key cannot collide the way a ':'-joined string can
            # (record data such as IPv6 addresses contains colons).
            unique_key = (
                entry.get('domain'),
                entry.get('record_class'),
                entry.get('record_type'),
                entry.get('ttl'),
                entry.get('record_data'),
            )
            if unique_key not in unique_entries:
                entry['is_latest'] = True
                unique_entries[unique_key] = entry

        return list(unique_entries.values())
    except Exception as e:
        print(f"Error loading DNS records from database: {e}")
        return []


# Second-level suffixes that extract_base_domain treats as part of the TLD.
# This is a small hand-maintained subset, not the full Public Suffix List.
MULTI_PART_TLDS = [
    'co.uk', 'org.uk', 'me.uk', 'ac.uk', 'gov.uk', 'net.uk', 'sch.uk',
    'com.au', 'net.au', 'org.au', 'edu.au', 'gov.au', 'asn.au', 'id.au',
    'co.nz', 'net.nz', 'org.nz', 'govt.nz', 'ac.nz', 'school.nz', 'geek.nz',
    'com.sg', 'edu.sg', 'gov.sg', 'net.sg', 'org.sg', 'per.sg',
    'co.za', 'org.za', 'web.za', 'net.za', 'gov.za', 'ac.za',
    'com.br', 'net.br', 'org.br', 'gov.br', 'edu.br',
    'co.jp', 'ac.jp', 'go.jp', 'or.jp', 'ne.jp', 'gr.jp',
    'co.in', 'firm.in', 'net.in', 'org.in', 'gen.in', 'ind.in',
    'edu.cn', 'gov.cn', 'net.cn', 'org.cn', 'com.cn', 'ac.cn',
    'com.mx', 'net.mx', 'org.mx', 'edu.mx', 'gob.mx'
]


def extract_base_domain(domain: str) -> str:
    """Return the base domain (SLD + TLD) of ``domain``.

    A trailing dot is ignored. Suffixes listed in ``MULTI_PART_TLDS`` are
    matched case-insensitively and kept whole (``www.example.co.uk`` ->
    ``example.co.uk``); otherwise the last two labels are returned. The
    original letter case of the input is preserved in the result.
    Empty/falsy input is returned unchanged, as are single-label names.
    """
    if not domain:
        return domain

    if domain.endswith('.'):
        domain = domain[:-1]

    parts = domain.split('.')
    if len(parts) <= 1:
        return domain

    lowered = [part.lower() for part in parts]
    for tld in MULTI_PART_TLDS:
        tld_len = tld.count('.') + 1
        if len(parts) > tld_len and '.'.join(lowered[-tld_len:]) == tld:
            return '.'.join(parts[-tld_len - 1:])

    return '.'.join(parts[-2:])


def get_unique_base_domains(specific_upload_id: Optional[str] = None) -> List[Dict]:
    """Return one ``{'domain', 'timestamp'}`` entry per distinct base domain.

    Base domains come from ``extract_base_domain(full_domain)``; the
    timestamp is that of the newest domain document mapping to it. An empty
    list is returned if reading the database fails (the error is printed).
    """
    try:
        domains = domains_table.all()

        if specific_upload_id:
            domains = [d for d in domains if d.get('upload_id') == specific_upload_id]

        domains.sort(key=_timestamp_key, reverse=True)

        unique_base_domains: Dict[str, Dict] = {}
        for domain in domains:
            base_domain = extract_base_domain(domain.get('full_domain', ''))
            if base_domain and base_domain not in unique_base_domains:
                unique_base_domains[base_domain] = {
                    'domain': base_domain,
                    'timestamp': domain.get('timestamp'),
                }

        return list(unique_base_domains.values())
    except Exception as e:
        print(f"Error getting unique base domains: {e}")
        return []


def get_unique_values(entries: List[Dict]) -> Dict[str, List[str]]:
    """Collect sorted distinct non-empty ``record_type`` / ``record_class`` values."""
    unique_values: Dict[str, Set[str]] = {
        "record_type": set(),
        "record_class": set(),
    }
    for entry in entries:
        for key, values in unique_values.items():
            value = entry.get(key)
            if value:
                values.add(value)

    return {key: sorted(values) for key, values in unique_values.items()}


def get_uploads():
    """Return all upload metadata documents, newest first."""
    uploads = uploads_table.all()
    uploads.sort(key=_timestamp_key, reverse=True)
    return uploads


def delete_upload(upload_id):
    """Remove an upload record and every domain/DNS document tagged with it.

    Returns:
        True on success (even if nothing matched), False if an error occurred
        (the error is printed).
    """
    try:
        Upload = TinyDBQuery()
        uploads_table.remove(Upload.id == upload_id)

        Record = TinyDBQuery()
        domains_table.remove(Record.upload_id == upload_id)
        dns_records_table.remove(Record.upload_id == upload_id)
        return True
    except Exception as e:
        print(f"Error deleting upload {upload_id}: {e}")
        return False


def _filter_dns_entries(
    entries: List[Dict],
    record_type: Optional[str] = None,
    record_class: Optional[str] = None,
    domain: Optional[str] = None,
) -> List[Dict]:
    """Apply exact type/class filters and a case-insensitive domain substring filter.

    Filters whose argument is falsy are skipped.
    """
    if record_type:
        entries = [e for e in entries if e.get("record_type") == record_type]
    if record_class:
        entries = [e for e in entries if e.get("record_class") == record_class]
    if domain:
        needle = domain.lower()
        entries = [e for e in entries if needle in (e.get("domain") or "").lower()]
    return entries


# Routes
@app.get("/", response_class=HTMLResponse)
async def home(
    request: Request,
    upload_id: Optional[str] = None,
    base_domains_only: Optional[bool] = False
):
    """Home page with upload form and domain listing"""
    domains = load_domains(upload_id, base_domains_only)
    uploads = get_uploads()
    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "domains": domains,
            "uploads": uploads,
            "base_domains_only": base_domains_only
        }
    )


@app.get("/delete-upload/{upload_id}", response_class=RedirectResponse)
async def delete_upload_route(upload_id: str):
    """Delete an upload and all associated records"""
    # NOTE(review): a state-changing GET can be triggered by link prefetchers;
    # consider POST/DELETE once the templates are updated.
    success = delete_upload(upload_id)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to delete upload")
    return RedirectResponse(url="/", status_code=303)


@app.post("/upload", response_class=RedirectResponse)
async def upload_csv(request: Request, file: UploadFile = File(...), description: str = Form(None)):
    """Handle file upload"""
    # Store the CSV contents, record upload metadata, then redirect home.
    # Raises HTTPException 400 for an empty file and 500 if processing or
    # storing the upload metadata fails.
    import hashlib
    import random

    content = await file.read()
    if not content:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    try:
        # Unique ID: timestamp plus a random suffix in case of same-second uploads.
        now = datetime.datetime.now()
        random_suffix = ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=6))
        upload_id = f"upload_{now.strftime('%Y%m%d_%H%M%S')}_{random_suffix}"

        print(f"Processing upload: ID={upload_id}, Filename={file.filename}, Content size={len(content)} bytes")

        domains_count, records_count = await process_csv_upload(content, upload_id, description)
        print(f"Upload processed: {domains_count} domains and {records_count} DNS records inserted")

        if domains_count == 0 and records_count == 0:
            print("WARNING: No records were inserted. Content may be invalid or empty.")
            try:
                preview = content.decode('utf-8')[:500]
                print(f"File preview: {preview}")
            except UnicodeDecodeError:
                print("Could not decode file content for preview")

        # The content hash lets identical re-uploads be recognised later.
        content_hash = hashlib.md5(content).hexdigest()
        upload_info = {
            "id": upload_id,
            "filename": file.filename,
            "description": description,
            "timestamp": datetime.datetime.now().isoformat(),
            "domains_count": domains_count,
            "records_count": records_count,
            "content_hash": content_hash
        }
        uploads_table.insert(upload_info)
    except Exception as e:
        # Returning a dict here would be fed into RedirectResponse as a URL;
        # surface a real HTTP error instead.
        print(f"Error in upload_csv: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return RedirectResponse(url="/", status_code=303)


@app.get("/dns-records", response_class=HTMLResponse)
async def dns_records(
    request: Request,
    upload_id: Optional[str] = None,
    record_type: Optional[str] = None,
    record_class: Optional[str] = None,
    domain: Optional[str] = None,
    deduplicate: Optional[bool] = True  # Default to showing only unique latest entries
):
    """DNS Records page with filtering"""
    entries = _filter_dns_entries(
        load_dns_entries(upload_id, deduplicate), record_type, record_class, domain
    )

    # Dropdown options come from every record in scope, not just the filtered ones.
    all_entries = load_dns_entries(upload_id, deduplicate=False)
    unique_values = get_unique_values(all_entries)
    uploads = get_uploads()

    return templates.TemplateResponse(
        "dns_records.html",
        {
            "request": request,
            "entries": entries,
            "unique_values": unique_values,
            "uploads": uploads,
            "deduplicate": deduplicate
        }
    )


# API Routes
@app.get("/api/uploads", response_model=List[Dict])
async def get_all_uploads():
    """API endpoint that returns all uploads"""
    return get_uploads()


@app.get("/api/domains", response_model=List[Dict])
async def get_domains(
    upload_id: Optional[str] = None,
    base_domains_only: Optional[bool] = False
):
    """API endpoint that returns all domains with optional filtering"""
    # load_domains handles deduplication and filtering.
    return load_domains(upload_id, base_domains_only)


@app.get("/api/base-domains", response_model=List[Dict])
async def get_base_domains(upload_id: Optional[str] = None):
    """API endpoint that returns only unique base domains"""
    return get_unique_base_domains(upload_id)


@app.get("/api/domains/{domain}", response_model=List[Dict])
async def get_domains_by_name(domain: str, upload_id: Optional[str] = None):
    """API endpoint that returns domains matching a specific domain name with optional filter by upload_id"""
    # Case-insensitive substring match over the deduplicated domain list.
    all_domains = load_domains(upload_id)
    needle = domain.lower()
    filtered = [item for item in all_domains if needle in (item.get("full_domain") or "").lower()]

    if not filtered:
        raise HTTPException(status_code=404, detail=f"No domains found matching: {domain}")

    return filtered


@app.get("/api/dns", response_model=List[Dict])
async def get_dns_entries(
    record_type: Optional[str] = None,
    record_class: Optional[str] = None,
    domain: Optional[str] = None,
    upload_id: Optional[str] = None,
    deduplicate: Optional[bool] = True
):
    """API endpoint that returns filtered DNS entries with optional deduplication"""
    return _filter_dns_entries(
        load_dns_entries(upload_id, deduplicate), record_type, record_class, domain
    )


@app.get("/api/dns/types", response_model=Dict[str, List])
async def get_unique_filter_values(upload_id: Optional[str] = None):
    """API endpoint that returns unique values for filters"""
    entries = load_dns_entries(upload_id)
    return get_unique_values(entries)


# Ensure the templates directory exists (the HTML templates themselves are not created here).
Path("templates").mkdir(exist_ok=True)

if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)