domainapi/main.py

650 lines
No EOL
24 KiB
Python

import csv
import re
import io
import datetime
from fastapi import FastAPI, Request, HTTPException, Query, UploadFile, File, Form
from fastapi.responses import HTMLResponse, RedirectResponse, Response
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import uvicorn
from pathlib import Path
from typing import List, Dict, Optional, Set
from tinydb import TinyDB, Query as TinyDBQuery
# Initialize FastAPI app
app = FastAPI(title="Domain Viewer", description="Upload and view domains and DNS records")
# Setup templates (directory is resolved relative to the working directory)
templates = Jinja2Templates(directory="templates")
# Setup TinyDB; the database file lives in $DB_DIR (default: current directory).
import os
# An empty DB_DIR falls back to the current directory. Building the path as
# f"{DB_DIR}/domains_db.json" would otherwise point at "/domains_db.json".
DB_DIR = os.environ.get('DB_DIR', '.') or '.'
DB_PATH = Path(DB_DIR) / "domains_db.json"
# Create the directory if needed so opening the database file does not fail
# on a fresh, not-yet-created DB_DIR.
Path(DB_DIR).mkdir(parents=True, exist_ok=True)
db = TinyDB(DB_PATH)
domains_table = db.table('domains')
dns_records_table = db.table('dns_records')
uploads_table = db.table('uploads')
# Process a domain entry and return its components
def process_domain_entry(domain_entry):
    """Normalize one raw domain cell into a domain-info dict.

    A single trailing dot (fully-qualified form) is removed. If nothing is
    left afterwards, ``None`` is returned; otherwise a dict holding only the
    ``full_domain`` key.
    """
    name = domain_entry[:-1] if domain_entry.endswith('.') else domain_entry
    if not name:
        return None
    # Keep the name unsplit; callers derive base domains separately.
    return {"full_domain": name}
# Upload CSV file and store in database
async def process_csv_upload(file_content, upload_id, description=None):
    """Parse an uploaded CSV of DNS rows and store its domains and records.

    Every non-empty row contributes its first column (via
    ``process_domain_entry``) to ``domains_table``, at most once per upload.
    Rows with at least five columns are also stored in ``dns_records_table``
    as ``domain, ttl, class, type, data``. Columns after the fourth are
    joined back together with commas to form the record data.

    Args:
        file_content: Raw bytes of the uploaded file, decoded as UTF-8. A
            leading byte-order mark is ignored.
        upload_id: Identifier written onto every stored document.
        description: Accepted for interface compatibility; not used here.

    Returns:
        ``(domains_inserted, dns_records_inserted)``. On any error the
        exception and traceback are printed and ``(0, 0)`` is returned.
    """
    domains_to_insert = []
    dns_records_to_insert = []
    unique_domains = set()
    # One timestamp for the whole upload, stamped on every stored document.
    timestamp = datetime.datetime.now().isoformat()
    row_count = 0
    try:
        # 'utf-8-sig' drops a leading BOM (typical of spreadsheet exports).
        # Plain 'utf-8' would keep it glued to the first domain name.
        file_text = file_content.decode('utf-8-sig')
        csv_reader = csv.reader(io.StringIO(file_text))
        # Print first few lines for debugging
        preview_lines = file_text.split('\n')[:5]
        print("CSV preview (first 5 lines):")
        for i, line in enumerate(preview_lines):
            print(f" Line {i+1}: {line}")
        for row_num, row in enumerate(csv_reader, 1):
            row_count += 1
            if not row:  # Skip empty rows
                print(f"Skipping empty row at line {row_num}")
                continue
            # The first column is always treated as a domain name.
            domain_entry = row[0]
            domain_info = process_domain_entry(domain_entry)
            if domain_info:
                unique_key = domain_info['full_domain']
                # Only the first occurrence of a name within this upload is stored.
                if unique_key not in unique_domains:
                    unique_domains.add(unique_key)
                    domain_info['upload_id'] = upload_id
                    domain_info['timestamp'] = timestamp
                    domains_to_insert.append(domain_info)
            else:
                print(f"Warning: Could not process domain entry at line {row_num}: {domain_entry}")
            # Rows with enough columns also describe a DNS record.
            if len(row) >= 5:
                domain = row[0]
                # Remove trailing dot from domain if present
                if domain.endswith('.'):
                    domain = domain[:-1]
                entry = {
                    "domain": domain,
                    "ttl": row[1],
                    "record_class": row[2],
                    "record_type": row[3],
                    # Join remaining columns back into the record data.
                    "record_data": ','.join(row[4:]),
                    "upload_id": upload_id,
                    "timestamp": timestamp,
                }
                # A leading label such as "_dmarc" marks a service record.
                # str.split always yields at least one element, so [0] is safe.
                first_label = domain.split('.')[0]
                if first_label.startswith('_'):
                    entry["service"] = first_label
                dns_records_to_insert.append(entry)
        print(f"Processed {row_count} rows from CSV")
        print(f"Records to insert: {len(domains_to_insert)} domains, {len(dns_records_to_insert)} DNS records")
        # Insert data into tables
        if domains_to_insert:
            print(f"Inserting {len(domains_to_insert)} domains into database")
            domains_table.insert_multiple(domains_to_insert)
        else:
            print("No domains to insert")
        if dns_records_to_insert:
            print(f"Inserting {len(dns_records_to_insert)} DNS records into database")
            dns_records_table.insert_multiple(dns_records_to_insert)
        else:
            print("No DNS records to insert")
        return len(domains_to_insert), len(dns_records_to_insert)
    except Exception as e:
        # Best-effort by design: report the failure and signal "nothing stored".
        import traceback
        print(f"Error processing CSV file: {e}")
        print(traceback.format_exc())
        return 0, 0
def load_domains(specific_upload_id: Optional[str] = None, base_domains_only: bool = False) -> List[Dict]:
    """Load domain documents from the ``domains`` table.

    Args:
        specific_upload_id: If non-empty, only documents whose ``upload_id``
            equals it are considered.
        base_domains_only: When False, the matching documents are returned
            as stored: NOT sorted and NOT deduplicated. When True:
            - each document gets a ``base_domain`` field;
            - only documents whose ``full_domain`` equals their base domain
              are kept;
            - the result is deduplicated by ``full_domain``, keeping the
              entry with the newest ``timestamp``, marked ``is_latest=True``.

    Returns:
        A list of domain dicts, or ``[]`` if reading the table fails (the
        error is printed).
    """
    try:
        domains = domains_table.all()
        if specific_upload_id:
            domains = [d for d in domains if d.get('upload_id') == specific_upload_id]
        # NOTE(review): this path does not deduplicate; confirm that is what
        # callers expecting one entry per domain name want.
        if not base_domains_only:
            return domains
        for domain in domains:
            domain['base_domain'] = extract_base_domain(domain.get('full_domain', ''))
        # Newest first, so the first occurrence of each name is the latest one.
        domains.sort(key=lambda d: d.get('timestamp', ''), reverse=True)
        unique_domains = {}
        for domain in domains:
            full_domain = domain.get('full_domain', '')
            # Keep only names that are themselves base domains (not subdomains).
            if full_domain != domain['base_domain']:
                continue
            if full_domain and full_domain not in unique_domains:
                domain['is_latest'] = True
                unique_domains[full_domain] = domain
        return list(unique_domains.values())
    except Exception as e:
        print(f"Error loading domains from database: {e}")
        return []
# Load DNS entries from database - with optional deduplication
def load_dns_entries(specific_upload_id: Optional[str] = None, deduplicate: bool = False) -> List[Dict]:
    """Load DNS record documents, newest ``timestamp`` first.

    Args:
        specific_upload_id: If non-empty, only records with this ``upload_id``
            are returned.
        deduplicate: When True, only the newest record for each distinct
            (domain, record_class, record_type, ttl, record_data)
            combination is kept, and it is marked ``is_latest=True``.

    Returns:
        A list of record dicts, or ``[]`` if reading the table fails (the
        error is printed).
    """
    try:
        entries = dns_records_table.all()
        if specific_upload_id:
            entries = [e for e in entries if e.get('upload_id') == specific_upload_id]
        # Newest first, so the first occurrence of each key is the latest one.
        entries.sort(key=lambda e: e.get('timestamp', ''), reverse=True)
        if not deduplicate:
            return entries
        unique_entries = {}
        for entry in entries:
            # A tuple key keeps fields distinct. A ':'-joined string would
            # collide when a field contains ':' or is None vs the text "None".
            unique_key = (
                entry.get('domain'),
                entry.get('record_class'),
                entry.get('record_type'),
                entry.get('ttl'),
                entry.get('record_data'),
            )
            if unique_key not in unique_entries:
                entry['is_latest'] = True
                unique_entries[unique_key] = entry
        return list(unique_entries.values())
    except Exception as e:
        print(f"Error loading DNS records from database: {e}")
        return []
# List of known multi-part TLDs
# Two-label public suffixes, grouped one country per line. Order is preserved.
MULTI_PART_TLDS = (
    "co.uk org.uk me.uk ac.uk gov.uk net.uk sch.uk "
    "com.au net.au org.au edu.au gov.au asn.au id.au "
    "co.nz net.nz org.nz govt.nz ac.nz school.nz geek.nz "
    "com.sg edu.sg gov.sg net.sg org.sg per.sg "
    "co.za org.za web.za net.za gov.za ac.za "
    "com.br net.br org.br gov.br edu.br "
    "co.jp ac.jp go.jp or.jp ne.jp gr.jp "
    "co.in firm.in net.in org.in gen.in ind.in "
    "edu.cn gov.cn net.cn org.cn com.cn ac.cn "
    "com.mx net.mx org.mx edu.mx gob.mx"
).split()
# Extract the base domain (SLD+TLD) from a full domain name
def extract_base_domain(domain: str, tlds: Optional[List[str]] = None) -> str:
    """Return the base domain (second-level label + TLD) of *domain*.

    A single trailing dot is removed first. If the name ends in a known
    multi-label suffix (``MULTI_PART_TLDS`` by default, or *tlds*), the
    label in front of that suffix plus the suffix is returned. Otherwise
    the last two labels are returned. Matching is case-insensitive, and the
    returned text keeps the caller's original casing. Falsy input is
    returned unchanged, and single-label names are returned as-is.

    Args:
        domain: The domain name to reduce.
        tlds: Optional override for the list of multi-label suffixes. When
            several match, the longest one wins.
    """
    if not domain:
        return domain
    if domain.endswith('.'):
        domain = domain[:-1]
    parts = domain.split('.')
    if len(parts) <= 1:
        return domain
    suffixes = {t.lower() for t in (MULTI_PART_TLDS if tlds is None else tlds)}
    lowered = [label.lower() for label in parts]
    # Try the longest candidate suffix first. A suffix must leave at least one
    # label in front of it, so a name that IS a suffix (e.g. "co.uk") falls
    # through to the two-label default.
    for size in range(len(parts) - 1, 1, -1):
        if '.'.join(lowered[-size:]) in suffixes:
            return '.'.join(parts[-size - 1:])
    return '.'.join(parts[-2:])
# Get all unique base domains from the database
def get_unique_base_domains(specific_upload_id: str = None) -> List[Dict]:
    """Return one ``{'domain', 'timestamp'}`` dict per distinct base domain.

    The timestamp is that of the newest stored domain mapping to the base
    domain. Only documents from *specific_upload_id* are considered when it
    is given. Names whose base domain is empty are skipped. On any error the
    message is printed and ``[]`` is returned.
    """
    try:
        docs = domains_table.all()
        if specific_upload_id:
            docs = [doc for doc in docs if doc.get('upload_id') == specific_upload_id]
        for doc in docs:
            doc['base_domain'] = extract_base_domain(doc.get('full_domain', ''))
        newest_first = sorted(docs, key=lambda doc: doc.get('timestamp', ''), reverse=True)
        seen: Dict[str, Dict] = {}
        for doc in newest_first:
            base = doc.get('base_domain', '')
            # First hit per base domain is the newest, thanks to the sort above.
            if base and base not in seen:
                seen[base] = {'domain': base, 'timestamp': doc.get('timestamp')}
        return list(seen.values())
    except Exception as exc:
        print(f"Error getting unique base domains: {exc}")
        return []
# Get unique values for filter dropdowns
def get_unique_values(entries: List[Dict]) -> Dict[str, List]:
    """Collect the distinct values used to populate the filter dropdowns.

    Args:
        entries: DNS record dicts.

    Returns:
        ``{"record_type": [...], "record_class": [...]}``, each a sorted list
        of the distinct truthy values found. Missing or empty fields are
        ignored.
    """
    fields = ("record_type", "record_class")
    return {
        field: sorted({entry[field] for entry in entries if entry.get(field)})
        for field in fields
    }
# Get all uploads
def get_uploads():
    """Return every upload record sorted by ``timestamp``, newest first."""
    return sorted(
        uploads_table.all(),
        key=lambda upload: upload.get('timestamp', ''),
        reverse=True,
    )
# Delete an upload and its associated data
def delete_upload(upload_id):
    """Delete an upload record and every domain/DNS record stamped with its id.

    Dependent rows are removed before the upload record itself. If a removal
    fails part-way, the upload therefore stays listed and the delete can be
    retried instead of leaving orphaned data behind.

    Returns:
        True on success (including when no rows matched), False if any
        removal raised; the error is printed.
    """
    try:
        # A single Query object can build any number of conditions.
        record = TinyDBQuery()
        domains_table.remove(record.upload_id == upload_id)
        dns_records_table.remove(record.upload_id == upload_id)
        uploads_table.remove(record.id == upload_id)
        return True
    except Exception as e:
        print(f"Error deleting upload {upload_id}: {e}")
        return False
# CSV Export Functions
def _format_csv_timestamp(value):
    """Render an ISO-8601 timestamp as 'YYYY-MM-DD HH:MM:SS' for CSV output.

    Falsy and non-string values are returned unchanged.
    """
    if not value or not isinstance(value, str):
        return value
    # Replace the 'T' separator and drop any fractional seconds.
    return value.replace('T', ' ').split('.')[0]


def _records_to_csv(records: List[Dict], fields: List[str]) -> str:
    """Serialize *records* to CSV text (header row included) using *fields*.

    Missing fields are written as empty cells, and a ``timestamp`` column is
    reformatted for readability. An empty input yields ``""``.
    """
    if not records:
        return ""
    csv_output = io.StringIO()
    writer = csv.DictWriter(csv_output, fieldnames=fields, extrasaction='ignore')
    writer.writeheader()
    for record in records:
        row = {field: record.get(field) for field in fields}
        if "timestamp" in row:
            row["timestamp"] = _format_csv_timestamp(row["timestamp"])
        writer.writerow(row)
    return csv_output.getvalue()


def domains_to_csv(domains: List[Dict]) -> str:
    """Convert domains data to CSV format.

    Columns are ``full_domain, timestamp``, with ``base_domain`` inserted
    after ``full_domain`` when the first domain dict carries that key.
    Returns ``""`` for an empty list.
    """
    if not domains:
        return ""
    fields = ["full_domain", "timestamp"]
    if "base_domain" in domains[0]:
        fields.insert(1, "base_domain")
    return _records_to_csv(domains, fields)


def dns_records_to_csv(records: List[Dict]) -> str:
    """Convert DNS records data to CSV format.

    Columns are ``domain, ttl, record_class, record_type, record_data,
    timestamp``. Returns ``""`` for an empty list.
    """
    fields = ["domain", "ttl", "record_class", "record_type", "record_data", "timestamp"]
    return _records_to_csv(records, fields)
# Routes
@app.get("/", response_class=HTMLResponse)
async def home(
    request: Request,
    upload_id: Optional[str] = None,
    base_domains_only: Optional[bool] = False
):
    """Render ``index.html`` with the upload form, domain list and uploads.

    The domain list comes from ``load_domains(upload_id, base_domains_only)``.
    """
    context = {
        "request": request,
        "domains": load_domains(upload_id, base_domains_only),
        "uploads": get_uploads(),
        "base_domains_only": base_domains_only,
    }
    return templates.TemplateResponse("index.html", context)
@app.get("/delete-upload/{upload_id}", response_class=RedirectResponse)
async def delete_upload_route(upload_id: str):
    """Delete one upload and its records, then redirect (303) back to "/".

    Raises:
        HTTPException: 500 if ``delete_upload`` reports failure.

    NOTE(review): this deletes data on a GET request, so anything that
    follows or prefetches the link triggers a delete; consider POST/DELETE.
    """
    if delete_upload(upload_id):
        return RedirectResponse(url="/", status_code=303)
    raise HTTPException(status_code=500, detail="Failed to delete upload")
@app.post("/upload", response_class=RedirectResponse)
async def upload_csv(request: Request, file: UploadFile = File(...), description: str = Form(None)):
    """Handle a CSV upload: store its rows, record the upload, redirect home.

    Steps:
        1. Assign a unique upload id (timestamp plus a random suffix).
        2. Pass the bytes to ``process_csv_upload``.
        3. Save an upload record with counts and an MD5 of the content.
        4. Redirect (303) to "/".

    Raises:
        HTTPException: 400 if the uploaded file is empty. 500 if any other
            error occurs; the message is passed through as ``detail``. The
            route is declared with ``response_class=RedirectResponse``, so
            returning a plain dict would not give the client an error status.
    """
    import hashlib
    import random

    content = await file.read()
    if not content:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")
    try:
        # Generate a unique ID for this upload with timestamp and a random suffix for extra uniqueness
        now = datetime.datetime.now()
        random_suffix = ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=6))
        upload_id = f"upload_{now.strftime('%Y%m%d_%H%M%S')}_{random_suffix}"
        print(f"Processing upload: ID={upload_id}, Filename={file.filename}, Content size={len(content)} bytes")
        domains_count, records_count = await process_csv_upload(content, upload_id, description)
        print(f"Upload processed: {domains_count} domains and {records_count} DNS records inserted")
        if domains_count == 0 and records_count == 0:
            print("WARNING: No records were inserted. Content may be invalid or empty.")
            # Decode only the first 500 bytes, replacing undecodable bytes, so
            # the debug preview can never raise.
            preview = content[:500].decode('utf-8', errors='replace')
            print(f"File preview: {preview}")
        # Store upload information with file hash to identify content changes
        content_hash = hashlib.md5(content).hexdigest()
        upload_info = {
            "id": upload_id,
            "filename": file.filename,
            "description": description,
            "timestamp": datetime.datetime.now().isoformat(),
            "domains_count": domains_count,
            "records_count": records_count,
            "content_hash": content_hash
        }
        uploads_table.insert(upload_info)
    except Exception as e:
        print(f"Error in upload_csv: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    # Redirect back to home page
    return RedirectResponse(url="/", status_code=303)
@app.get("/dns-records", response_class=HTMLResponse)
async def dns_records(
    request: Request,
    upload_id: Optional[str] = None,
    record_type: Optional[str] = None,
    record_class: Optional[str] = None,
    domain: Optional[str] = None,
    deduplicate: Optional[bool] = True  # Default to showing only unique latest entries
):
    """Render ``dns_records.html`` with a filtered list of DNS records.

    Records are loaded (optionally for one upload, optionally reduced to the
    newest copy of each distinct record). They are then narrowed by exact
    ``record_type`` / ``record_class`` and by a case-insensitive substring
    match on ``domain``. Dropdown values are computed from the unfiltered,
    non-deduplicated records of the same upload selection.
    """
    rows = load_dns_entries(upload_id, deduplicate)
    if record_type:
        rows = [row for row in rows if row.get("record_type") == record_type]
    if record_class:
        rows = [row for row in rows if row.get("record_class") == record_class]
    if domain:
        needle = domain.lower()
        rows = [row for row in rows if needle in row.get("domain", "").lower()]
    dropdown_values = get_unique_values(load_dns_entries(upload_id, deduplicate=False))
    context = {
        "request": request,
        "entries": rows,
        "unique_values": dropdown_values,
        "uploads": get_uploads(),
        "deduplicate": deduplicate,
    }
    return templates.TemplateResponse("dns_records.html", context)
@app.get("/export-domains-csv")
async def export_domains_csv(
    upload_id: Optional[str] = None,
    base_domains_only: Optional[bool] = False
):
    """Download the selected domains as a timestamped CSV attachment.

    The selection is ``load_domains(upload_id, base_domains_only)``.
    """
    body = domains_to_csv(load_domains(upload_id, base_domains_only))
    stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    filename = f"domains_export_{stamp}.csv"
    disposition = f"attachment; filename={filename}"
    return Response(content=body, media_type="text/csv", headers={"Content-Disposition": disposition})
@app.get("/export-dns-csv")
async def export_dns_csv(
    upload_id: Optional[str] = None,
    record_type: Optional[str] = None,
    record_class: Optional[str] = None,
    domain: Optional[str] = None,
    deduplicate: Optional[bool] = True
):
    """Download DNS records as a timestamped CSV attachment.

    Records are narrowed by exact ``record_type`` / ``record_class`` and by a
    case-insensitive substring match on ``domain``.
    """
    selected = load_dns_entries(upload_id, deduplicate)
    if record_type:
        selected = [rec for rec in selected if rec.get("record_type") == record_type]
    if record_class:
        selected = [rec for rec in selected if rec.get("record_class") == record_class]
    if domain:
        fragment = domain.lower()
        selected = [rec for rec in selected if fragment in rec.get("domain", "").lower()]
    body = dns_records_to_csv(selected)
    stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    filename = f"dns_records_export_{stamp}.csv"
    return Response(
        content=body,
        media_type="text/csv",
        headers={"Content-Disposition": f"attachment; filename={filename}"},
    )
# API Routes
@app.get("/api/uploads", response_model=List[Dict])
async def get_all_uploads():
    """Return every upload record, in the order produced by ``get_uploads``."""
    uploads = get_uploads()
    return uploads
@app.get("/api/domains", response_model=List[Dict])
async def get_domains(
    upload_id: Optional[str] = None,
    base_domains_only: Optional[bool] = False
):
    """Return domains as JSON; selection is delegated to ``load_domains``."""
    return load_domains(upload_id, base_domains_only)
@app.get("/api/base-domains", response_model=List[Dict])
async def get_base_domains(upload_id: Optional[str] = None):
    """Return one entry per distinct base domain (see ``get_unique_base_domains``)."""
    return get_unique_base_domains(upload_id)
@app.get("/api/domains/{domain}", response_model=List[Dict])
async def get_domains_by_name(domain: str, upload_id: Optional[str] = None):
    """Return stored domains whose name contains *domain* (case-insensitive).

    Candidates come from ``load_domains(upload_id)`` with its default
    ``base_domains_only=False``. That path does not deduplicate, so a name
    uploaded several times appears once per upload.

    Raises:
        HTTPException: 404 when nothing matches.
    """
    needle = domain.lower()
    matches = [
        item for item in load_domains(upload_id)
        # A document without a usable 'full_domain' is skipped rather than
        # raising KeyError/AttributeError and failing the whole request.
        if needle in (item.get("full_domain") or "").lower()
    ]
    if not matches:
        raise HTTPException(status_code=404, detail=f"No domains found matching: {domain}")
    return matches
@app.get("/api/dns", response_model=List[Dict])
async def get_dns_entries(
    record_type: Optional[str] = None,
    record_class: Optional[str] = None,
    domain: Optional[str] = None,
    upload_id: Optional[str] = None,
    deduplicate: Optional[bool] = True
):
    """Return DNS records as JSON, optionally deduplicated and filtered.

    Filters: exact ``record_type`` / ``record_class`` and a case-insensitive
    substring match on ``domain``.
    """
    found = load_dns_entries(upload_id, deduplicate)
    if record_type:
        found = [item for item in found if item.get("record_type") == record_type]
    if record_class:
        found = [item for item in found if item.get("record_class") == record_class]
    if domain:
        wanted = domain.lower()
        found = [item for item in found if wanted in item.get("domain", "").lower()]
    return found
@app.get("/api/dns/types", response_model=Dict[str, List])
async def get_unique_filter_values(upload_id: Optional[str] = None):
    """Return the distinct record types and classes (non-deduplicated source)."""
    return get_unique_values(load_dns_entries(upload_id))
# Ensure the Jinja2 templates directory exists at import time. Only the
# directory is created; the HTML templates must be provided separately.
Path("templates").mkdir(exist_ok=True)

if __name__ == "__main__":
    # Development server with auto-reload. "main:app" is an import string, so
    # this presumably has to be launched from this file's directory.
    uvicorn.run(app="main:app", host="0.0.0.0", port=8000, reload=True)